# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from .ch05 import calc_loss_batch, evaluate_model, generate_and_print_sample

import math
import torch


def find_highest_gradient(model):
    # Return the largest individual gradient value across all model parameters
    # (handy for inspecting the effect of gradient clipping)
    max_grad = None
    for param in model.parameters():
        if param.grad is not None:
            grad_values = param.grad.data.flatten()
            max_grad_param = grad_values.max()
            if max_grad is None or max_grad_param > max_grad:
                max_grad = max_grad_param
    return max_grad
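

# A minimal usage sketch for `find_highest_gradient` above (not part of the original
# listing; it assumes a `model` whose gradients were just populated by `loss.backward()`):
#
#     loss.backward()
#     print(find_highest_gradient(model))   # largest single gradient value before clipping
#     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#     print(find_highest_gradient(model))   # never larger than the value printed above

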
def train_model(model, train_loader, val_loader, optimizer, device,
                n_epochs, eval_freq, eval_iter, start_context, tokenizer,
                warmup_steps, initial_lr=3e-05, min_lr=1e-6, orig_book_version=False):

    train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []
    tokens_seen, global_step = 0, -1

    # Retrieve the maximum learning rate from the optimizer
    peak_lr = optimizer.param_groups[0]["lr"]

    # Calculate the total number of iterations in the training process
    total_training_steps = len(train_loader) * n_epochs

    # Calculate the learning rate increment during the warmup phase
    lr_increment = (peak_lr - initial_lr) / warmup_steps

    for epoch in range(n_epochs):
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            global_step += 1

            # Adjust the learning rate based on the current phase (warmup or cosine annealing)
            if global_step < warmup_steps:
                # Linear warmup
                lr = initial_lr + global_step * lr_increment
            else:
                # Cosine annealing after warmup
                progress = ((global_step - warmup_steps) /
                            (total_training_steps - warmup_steps))
                lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

            # Apply the calculated learning rate to the optimizer
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
            track_lrs.append(lr)  # Store the current learning rate

            # Calculate and backpropagate the loss
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()

            # Apply gradient clipping after the warmup phase to avoid exploding gradients
            if orig_book_version:
                if global_step > warmup_steps:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            else:
                if global_step >= warmup_steps:  # the book originally used global_step > warmup_steps, which led to a skipped clipping step after warmup
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            tokens_seen += input_batch.numel()

            # Periodically evaluate the model on the training and validation sets
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader,
                    device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                # Print the current losses
                print(f"Ep {epoch+1} (Iter {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, "
                      f"Val loss {val_loss:.3f}")

        # Generate and print a sample from the model to monitor progress
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen, track_lrs
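

# Illustrative sketch (not from the book's code): it factors the learning-rate
# schedule used inside `train_model` into a standalone function so the linear-warmup
# and cosine-annealing phases can be inspected (or plotted) in isolation; the
# parameter names mirror the ones used above.

def lr_schedule_sketch(step, warmup_steps, total_training_steps,
                       peak_lr, initial_lr=3e-05, min_lr=1e-6):
    # Linear warmup: ramp from initial_lr toward peak_lr over the first warmup_steps steps
    if step < warmup_steps:
        return initial_lr + step * (peak_lr - initial_lr) / warmup_steps
    # Cosine annealing: decay from peak_lr down to min_lr over the remaining steps
    progress = (step - warmup_steps) / (total_training_steps - warmup_steps)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))


# A hypothetical invocation sketch for `train_model` (illustrative values only;
# `model`, `train_loader`, `val_loader`, `tokenizer`, and `device` are assumed to be
# set up as in chapter 5):
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
#     n_epochs = 15
#     warmup_steps = int(0.2 * len(train_loader) * n_epochs)
#     train_losses, val_losses, tokens_seen, lrs = train_model(
#         model, train_loader, val_loader, optimizer, device,
#         n_epochs=n_epochs, eval_freq=5, eval_iter=1,
#         start_context="Every effort moves you", tokenizer=tokenizer,
#         warmup_steps=warmup_steps,
#     )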
