diff --git a/ch05/05_bonus_hparam_tuning/hparam_search.py b/ch05/05_bonus_hparam_tuning/hparam_search.py index 4f2a2ce..e2e68e2 100644 --- a/ch05/05_bonus_hparam_tuning/hparam_search.py +++ b/ch05/05_bonus_hparam_tuning/hparam_search.py @@ -64,8 +64,7 @@ def evaluate_model(model, train_loader, val_loader, device, eval_iter): def train_model(model, train_loader, val_loader, optimizer, device, - n_epochs, eval_freq, eval_iter, - encoded_start_context, tokenizer, warmup_iters=10, + n_epochs, eval_iter, warmup_iters=10, initial_lr=3e-05, min_lr=1e-6): global_step = 0 @@ -192,9 +191,7 @@ if __name__ == "__main__": train_loss, val_loss = train_model( model, train_loader, val_loader, optimizer, device, n_epochs=HPARAM_CONFIG["n_epochs"], - eval_freq=5, eval_iter=1, - encoded_start_context=encoded_tensor, - tokenizer=tokenizer, + eval_iter=1, warmup_iters=HPARAM_CONFIG["warmup_iters"], initial_lr=HPARAM_CONFIG["initial_lr"], min_lr=HPARAM_CONFIG["min_lr"]