mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-16 10:04:53 +00:00
fix learning rate scheduler

This commit is contained in:
parent 5adc6a8f69
commit 5a1e0eecce
File diff suppressed because one or more lines are too long
@@ -65,13 +65,13 @@ def train_model(model, train_loader, val_loader, optimizer, device,
                 initial_lr=3e-05, min_lr=1e-6):
     global_step = 0
 
-    max_lr = optimizer.defaults["lr"]
+    max_lr = optimizer.param_groups[0]["lr"]
 
     # Calculate total number of iterations
     total_training_iters = len(train_loader) * n_epochs
 
     # Calculate the learning rate increment at each step during warmup
-    lr_increment = (optimizer.defaults["lr"] - initial_lr) / warmup_iters
+    lr_increment = (optimizer.param_groups[0]["lr"] - initial_lr) / warmup_iters
 
     for epoch in range(n_epochs):
         model.train()
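Why the change: in PyTorch, optimizer.defaults["lr"] records the learning rate passed to the constructor and never changes afterwards, whereas optimizer.param_groups[0]["lr"] is the live per-group value that warmup and decay code read and mutate. A minimal sketch contrasting the two lookups (the toy model and values below are illustrative, not from the repository):

import torch

# Toy example (hypothetical model/values) contrasting the two lookups.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

# Simulate a scheduler step that lowers the live learning rate.
for param_group in optimizer.param_groups:
    param_group["lr"] = 1e-5

print(optimizer.defaults["lr"])         # 0.0005 -- frozen constructor default
print(optimizer.param_groups[0]["lr"])  # 1e-05  -- the rate actually applied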
|||||||
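The variables in the hunk (initial_lr, min_lr, max_lr, warmup_iters, lr_increment, total_training_iters, global_step) suggest a linear-warmup schedule followed by a decay toward min_lr. A hedged sketch of such a schedule under those assumptions; the full train_model body is suppressed above, so this is an illustration rather than the repository's exact code:

import math

def lr_at_step(global_step, warmup_iters, total_training_iters,
               initial_lr=3e-5, max_lr=5e-4, min_lr=1e-6):
    # Linear warmup: climb from initial_lr to max_lr over warmup_iters steps.
    if global_step < warmup_iters:
        lr_increment = (max_lr - initial_lr) / warmup_iters
        return initial_lr + global_step * lr_increment
    # Cosine decay (assumed here): anneal from max_lr down to min_lr.
    progress = (global_step - warmup_iters) / (total_training_iters - warmup_iters)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

# Inside the training loop, the computed rate would be written back into
# optimizer.param_groups, which is exactly the structure this commit's
# corrected lookup reads:
# for param_group in optimizer.param_groups:
#     param_group["lr"] = lr_at_step(...)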