# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import math

import torch

from .ch05 import calc_loss_batch, evaluate_model, generate_and_print_sample


def find_highest_gradient(model):
    # Return the largest gradient value found across all model parameters
    # (useful for inspecting gradient magnitudes, e.g., before/after clipping)
    max_grad = None
    for param in model.parameters():
        if param.grad is not None:
            grad_values = param.grad.data.flatten()
            max_grad_param = grad_values.max()
            if max_grad is None or max_grad_param > max_grad:
                max_grad = max_grad_param
    return max_grad


def train_model(model, train_loader, val_loader, optimizer, device,
                n_epochs, eval_freq, eval_iter, start_context, tokenizer,
                warmup_steps, initial_lr=3e-05, min_lr=1e-6,
                orig_book_version=False):

    train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []
    tokens_seen, global_step = 0, -1

    # Retrieve the maximum learning rate from the optimizer
    peak_lr = optimizer.param_groups[0]["lr"]

    # Calculate the total number of iterations in the training process
    total_training_steps = len(train_loader) * n_epochs

    # Calculate the learning rate increment during the warmup phase
    lr_increment = (peak_lr - initial_lr) / warmup_steps

    for epoch in range(n_epochs):
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            global_step += 1

            # Adjust the learning rate based on the current phase
            # (warmup or cosine annealing)
            if global_step < warmup_steps:
                # Linear warmup
                lr = initial_lr + global_step * lr_increment
            else:
                # Cosine annealing after warmup
                progress = ((global_step - warmup_steps) /
                            (total_training_steps - warmup_steps))
                lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

            # Apply the calculated learning rate to the optimizer
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
            track_lrs.append(lr)  # Store the current learning rate

            # Calculate and backpropagate the loss
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()

            # Apply gradient clipping after the warmup phase to avoid exploding gradients
            if orig_book_version:
                if global_step > warmup_steps:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            else:
                # The book originally used `global_step > warmup_steps`,
                # which led to a skipped clipping step right after warmup
                if global_step >= warmup_steps:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            tokens_seen += input_batch.numel()

            # Periodically evaluate the model on the training and validation sets
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader,
                    device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                # Print the current losses
                print(f"Ep {epoch+1} (Iter {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, "
                      f"Val loss {val_loss:.3f}")

        # Generate and print a sample from the model to monitor progress
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen, track_lrs
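

# Hedged sketch, not part of the book's original code: a small standalone demo
# that traces the warmup + cosine-annealing schedule used in train_model above,
# so its shape can be sanity-checked without a model, optimizer, or data loader.
# The step counts and learning-rate values below are illustrative assumptions.
if __name__ == "__main__":
    peak_lr, demo_initial_lr, demo_min_lr = 5e-4, 3e-5, 1e-6  # assumed values
    demo_warmup_steps, demo_total_steps = 20, 200             # assumed values
    demo_lr_increment = (peak_lr - demo_initial_lr) / demo_warmup_steps

    demo_lrs = []
    for step in range(demo_total_steps):
        if step < demo_warmup_steps:
            # Linear warmup, same formula as in train_model
            demo_lr = demo_initial_lr + step * demo_lr_increment
        else:
            # Cosine annealing from peak_lr down toward demo_min_lr
            demo_progress = ((step - demo_warmup_steps) /
                             (demo_total_steps - demo_warmup_steps))
            demo_lr = demo_min_lr + (peak_lr - demo_min_lr) * 0.5 * (
                1 + math.cos(math.pi * demo_progress)
            )
        demo_lrs.append(demo_lr)

    print(f"First LR: {demo_lrs[0]:.2e}, "
          f"peak LR: {max(demo_lrs):.2e}, "
          f"final LR: {demo_lrs[-1]:.2e}")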