mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-30 17:29:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			43 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			43 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
 | |
| # Source for "Build a Large Language Model From Scratch"
 | |
| #   - https://www.manning.com/books/build-a-large-language-model-from-scratch
 | |
| # Code: https://github.com/rasbt/LLMs-from-scratch
 | |
| 
 | |
| import torch
 | |
| import math
 | |
| 
 | |
| 
 | |
class LoRALayer(torch.nn.Module):
    """Low-rank update branch computing ``alpha * (x @ A @ B)``.

    Args:
        in_dim: Size of the last dimension of the input ``x``.
        out_dim: Size of the last dimension of the output.
        rank: Inner (bottleneck) dimension shared by ``A`` and ``B``.
        alpha: Scalar multiplier applied to the low-rank product. It is used
            as-is; it is not divided by ``rank``.

    Parameters:
        A: ``(in_dim, rank)`` matrix, randomly initialised.
        B: ``(rank, out_dim)`` matrix, initialised to zeros, so the layer
            outputs zeros until ``B`` is updated.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        down_proj = torch.empty(in_dim, rank)
        # NOTE(review): with A laid out as (in_dim, rank), kaiming_uniform_
        # derives fan_in from dim 1 (= rank), so values are drawn from
        # U(-1/sqrt(rank), 1/sqrt(rank)) rather than the 1/sqrt(in_dim) bound
        # a torch.nn.Linear with the same input size would use.
        torch.nn.init.kaiming_uniform_(down_proj, a=math.sqrt(5))
        self.A = torch.nn.Parameter(down_proj)
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        """Return ``alpha * ((x @ A) @ B)``; ``x``'s last dim must be ``in_dim``."""
        projected = x @ self.A  # (..., in_dim) -> (..., rank)
        return self.alpha * (projected @ self.B)  # (..., rank) -> (..., out_dim)
 | |
| 
 | |
| 
 | |
class LinearWithLoRA(torch.nn.Module):
    """Wrap an existing ``torch.nn.Linear`` and add a ``LoRALayer`` branch.

    The wrapped layer is stored unchanged as ``self.linear``. The adapter,
    ``self.lora``, is sized from the layer's ``in_features``/``out_features``.
    The forward output is ``linear(x) + lora(x)``.

    Args:
        linear: The ``torch.nn.Linear`` module to wrap.
        rank: Rank passed through to ``LoRALayer``.
        alpha: Scaling factor passed through to ``LoRALayer``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        fan_in, fan_out = linear.in_features, linear.out_features
        self.lora = LoRALayer(fan_in, fan_out, rank, alpha)

    def forward(self, x):
        """Return the wrapped layer's output plus the low-rank adapter's output."""
        base_out = self.linear(x)
        return base_out + self.lora(x)
 | |
| 
 | |
| 
 | |
def replace_linear_with_lora(model, rank, alpha):
    """Replace every ``torch.nn.Linear`` below *model* with ``LinearWithLoRA``, in place.

    The module tree is walked recursively through ``named_children``. Each
    ``torch.nn.Linear`` child is swapped on its parent (via ``setattr``) for a
    ``LinearWithLoRA`` that wraps it. The freshly created wrapper is not
    descended into. Every other child is searched recursively. Returns None.

    NOTE(review): ``LinearWithLoRA`` is not a ``torch.nn.Linear`` subclass, and
    its ``linear`` child is one. Calling this function a second time on the same
    model therefore wraps those inner layers again (nested adapters).

    Args:
        model: Root ``torch.nn.Module`` to modify.
        rank: Rank forwarded to each ``LinearWithLoRA``.
        alpha: Scaling factor forwarded to each ``LinearWithLoRA``.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, torch.nn.Linear):
            # Not a Linear layer itself: look for Linear layers deeper down.
            replace_linear_with_lora(child, rank, alpha)
            continue
        setattr(model, child_name, LinearWithLoRA(child, rank, alpha))
 | 
