# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import math

import torch


class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))  # similar to standard weight initialization
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x


class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)


def replace_linear_with_lora(model, rank, alpha):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            # Replace the Linear layer with LinearWithLoRA
            setattr(model, name, LinearWithLoRA(module, rank, alpha))
        else:
            # Recursively apply the same function to child modules
            replace_linear_with_lora(module, rank, alpha)
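

# The demo below is not part of the original file; it is a minimal usage
# sketch showing how replace_linear_with_lora could be applied. The toy
# Sequential model and the rank/alpha values are illustrative assumptions,
# not taken from the book.
if __name__ == "__main__":
    torch.manual_seed(123)

    # Small stand-in model with two Linear layers
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 4),
    )

    # Freeze the original weights first so that only the newly added
    # LoRA matrices A and B remain trainable after the replacement
    for param in model.parameters():
        param.requires_grad = False

    replace_linear_with_lora(model, rank=4, alpha=8)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable} / {total}")

    # The forward pass works as before; each Linear output now also
    # includes the low-rank update alpha * (x @ A @ B)
    x = torch.randn(2, 8)
    print(model(x).shape)  # torch.Size([2, 4])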