Add experiment with gradient accumulation

rasbt 2024-05-17 21:31:22 -05:00
parent 623bc19665
commit 10ebc47720
2 changed files with 49 additions and 19 deletions

README.md

@@ -21,7 +21,8 @@ For example,
| 8 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100% | 96.64% | 93.67% | 0.69 min | A100 |
| 9 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 99.52% | 97.99% | 97.67% | 0.75 min | A100 |
| 10 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
- | 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding | 97.42% | 95.30% | 95.00% | 1.71 min | A100 |
+ | 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
+ | 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
 
@@ -40,7 +41,8 @@ You can use the following code to reproduce the experiments:
- Row 8: `python additional-experiments.py --weights random --trainable_layers all`
- Row 9: `python additional-experiments.py --trainable_layers lora --lora_rank 16 --lora_alpha 8`
- Row 10: `python additional-experiments.py --context_length "model_context_length"`
- - Row 11: `python additional-experiments.py --no_padding`
+ - Row 11: `python additional-experiments.py --no_padding --batch_size 1`
+ - Row 12: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes if you don't have access to a GPU.
@@ -62,4 +64,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
7. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 10)**: Padding the input to the full supported context length results in significantly worse performance.
- 8. **Padding vs no padding (Row 1 vs 11)**: The `--no_padding` option disables the padding in the dataset and trains the model with a batch size of 1 where the inputs have variable lengths. This results in exactly the same test set accuracy but takes substantially longer to train.
+ 8. **Padding vs. no padding (Row 1 vs. 11 and 12)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 12, we additionally enable gradient accumulation with 8 steps to achieve the same effective batch size as in the other experiments, which helps reduce overfitting and slightly boosts the test set accuracy (see the sketch below).
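
To make the equivalence between the effective batch sizes in rows 1 and 12 concrete, here is a minimal sketch (toy model, data, and hyperparameters invented for illustration; not part of this repository). It shows that scaling each loss by `1/accumulation_steps` and deferring `optimizer.step()` reproduces the weight update of a single larger batch:

```python
import torch

torch.manual_seed(123)
model_a = torch.nn.Linear(4, 2)  # updated with one batch of size 8
model_b = torch.nn.Linear(4, 2)  # updated with 8 batches of size 1 + accumulation
model_b.load_state_dict(model_a.state_dict())  # identical starting weights

opt_a = torch.optim.SGD(model_a.parameters(), lr=0.1)
opt_b = torch.optim.SGD(model_b.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# (a) batch_size=8, accumulation_steps=1: one backward pass, one update
loss_fn(model_a(x), y).backward()
opt_a.step()

# (b) batch_size=1, accumulation_steps=8: eight scaled backward passes, one update
accumulation_steps = 8
for i in range(8):
    loss = loss_fn(model_b(x[i:i+1]), y[i:i+1]) / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across iterations
opt_b.step()

print(torch.allclose(model_a.weight, model_b.weight, atol=1e-6))  # True
```

The equivalence holds for any optimizer as long as a single update is taken, but note that it breaks down for layers whose behavior depends on per-batch statistics (e.g., batch normalization), which this GPT-2 classifier does not use.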

additional-experiments.py

@@ -222,7 +222,8 @@ def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
-                             eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1):
+                             eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1,
+                             accumulation_steps=1):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1
@@ -231,11 +232,21 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
-         for input_batch, target_batch in train_loader:
-             optimizer.zero_grad()  # Reset loss gradients from previous epoch
+         for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
            loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
+             # Use gradient accumulation if accumulation_steps > 1
+             # See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html
+             # for an explanation
+             loss /= accumulation_steps
            loss.backward()  # Calculate loss gradients
-             optimizer.step()  # Update model weights using loss gradients
+             # Apply the accumulated gradients every accumulation_steps batches
+             if (batch_idx + 1) % accumulation_steps == 0:
+                 optimizer.step()  # Update model weights using accumulated gradients
+                 optimizer.zero_grad()  # Reset loss gradients for the next accumulation cycle
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1
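
A subtlety worth noting about the accumulation condition above: if the number of batches per epoch is not divisible by `accumulation_steps`, the gradients from the trailing batches are never applied and would carry over into the next epoch. A minimal guard (a hypothetical variant, not part of this commit) is to also flush the accumulated gradients on the last batch:

# Hypothetical variant of the update condition: additionally step on the
# final batch of the epoch so leftover accumulated gradients are applied.
is_update_step = (batch_idx + 1) % accumulation_steps == 0
is_last_batch = (batch_idx + 1) == len(train_loader)
if is_update_step or is_last_batch:
    optimizer.step()
    optimizer.zero_grad()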
@@ -341,8 +352,8 @@ if __name__ == "__main__":
        action='store_true',
        default=False,
        help=(
-             "Enable no padding. When this flag is set it will train"
-             " the model with a batch size of 1 and no padding."
+             "Disable padding, which means each example may have a different length."
+             " This requires setting `--batch_size 1`."
        )
    )
    parser.add_argument(
@@ -353,6 +364,27 @@ if __name__ == "__main__":
            "Number of training epochs."
        )
    )
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=8,
+         help=(
+             "The batch size used for training."
+         )
+     )
+     parser.add_argument(
+         "--accumulation_steps",
+         type=int,
+         default=1,
+         help=(
+             "Number of steps to accumulate gradients over before an optimizer update."
+             " See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html for an explanation."
+             " For example, setting `batch_size=8` and `accumulation_steps=1` computes exactly the same"
+             " loss and weight updates as setting `batch_size=1` and `accumulation_steps=8`; however,"
+             " the latter setting uses more iterations."
+         )
+     )
    args = parser.parse_args()
@@ -455,14 +487,9 @@ if __name__ == "__main__":
    num_workers = 0
-     if args.no_padding:
-         batch_size = 1
-     else:
-         batch_size = 8
    train_loader = DataLoader(
        dataset=train_dataset,
-         batch_size=batch_size,
+         batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
@@ -470,14 +497,14 @@ if __name__ == "__main__":
    val_loader = DataLoader(
        dataset=val_dataset,
-         batch_size=batch_size,
+         batch_size=args.batch_size,
        num_workers=num_workers,
        drop_last=False,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
-         batch_size=batch_size,
+         batch_size=args.batch_size,
        num_workers=num_workers,
        drop_last=False,
    )
@@ -493,7 +520,8 @@ if __name__ == "__main__":
    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
-         tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token
+         tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token,
+         accumulation_steps=args.accumulation_steps
    )
    end_time = time.time()
@@ -510,4 +538,4 @@ if __name__ == "__main__":
    print(f"Training accuracy: {train_accuracy*100:.2f}%")
    print(f"Validation accuracy: {val_accuracy*100:.2f}%")
    print(f"Test accuracy: {test_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")