fix learning rate scheduler

rasbt 2024-06-03 07:06:42 -05:00
parent 5adc6a8f69
commit 5a1e0eecce
2 changed files with 17 additions and 22 deletions

File diff suppressed because one or more lines are too long


@@ -65,13 +65,13 @@ def train_model(model, train_loader, val_loader, optimizer, device,
                 initial_lr=3e-05, min_lr=1e-6):
     global_step = 0

-    max_lr = optimizer.defaults["lr"]
+    max_lr = optimizer.param_groups[0]["lr"]

     # Calculate total number of iterations
     total_training_iters = len(train_loader) * n_epochs

     # Calculate the learning rate increment at each step during warmup
-    lr_increment = (optimizer.defaults["lr"] - initial_lr) / warmup_iters
+    lr_increment = (optimizer.param_groups[0]["lr"] - initial_lr) / warmup_iters

     for epoch in range(n_epochs):
         model.train()
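
The reason for the change: `optimizer.defaults["lr"]` is only a snapshot of the `lr` argument passed to the optimizer's constructor, whereas `optimizer.param_groups[0]["lr"]` is the value the optimizer actually applies and the one a scheduler (including the warmup loop in `train_model`) mutates in place. The two agree immediately after construction, but only `param_groups` stays correct once the learning rate is changed, e.g. when resuming training or when the caller has adjusted the rate on a param group. The following minimal sketch illustrates the difference; the parameter, the learning-rate values, and `warmup_iters` are illustrative assumptions, not taken from the commit:

import torch

# One dummy parameter so the optimizer has something to manage.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=5e-4)

# Simulate one warmup step the way train_model does it: compute the new
# learning rate and write it into each param group.
initial_lr, warmup_iters = 3e-5, 10           # assumed values for illustration
max_lr = optimizer.param_groups[0]["lr"]      # 5e-4, the peak learning rate
lr_increment = (max_lr - initial_lr) / warmup_iters

lr = initial_lr + 1 * lr_increment            # LR after the first warmup step
for param_group in optimizer.param_groups:
    param_group["lr"] = lr

print(optimizer.defaults["lr"])               # still 5e-4: never updated
print(optimizer.param_groups[0]["lr"])        # 7.7e-5: the live value in use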