# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# File for internal use (unit tests)

import io
import os
import sys
import types

import nbformat
import pytest
import torch
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


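# The fixture below executes the relevant function and class definitions
# straight from the chapter notebook, so the tests always exercise the
# notebook's own code rather than a copied-out version.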
@pytest.fixture(scope="module")
def notebook():
    def import_definitions_from_notebook(fullname, names):
        # Get the directory of the current test file
        current_dir = os.path.dirname(__file__)
        path = os.path.join(current_dir, "..", fullname + ".ipynb")
        path = os.path.normpath(path)

        # Load the notebook
        if not os.path.exists(path):
            raise FileNotFoundError(f"Notebook file not found at: {path}")

        with io.open(path, "r", encoding="utf-8") as f:
            nb = nbformat.read(f, as_version=4)

        # Create a module to store the imported functions and classes
        mod = types.ModuleType(fullname)
        sys.modules[fullname] = mod

        # Go through the notebook cells and only execute function or class definitions
        for cell in nb.cells:
            if cell.cell_type == "code":
                cell_code = cell.source
                for name in names:
                    # Check for function or class definitions
                    if f"def {name}" in cell_code or f"class {name}" in cell_code:
                        exec(cell_code, mod.__dict__)
        return mod

    # Specify the notebook name and functions/classes to import
    fullname = "converting-gpt-to-llama2"
    names = ["precompute_rope_params", "compute_rope", "SiLU", "RMSNorm"]

    # Import the required functions and classes from the notebook
    return import_definitions_from_notebook(fullname, names)


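# Reset the PyTorch seed before every test so that the random query/key
# tensors below are reproducible.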
@pytest.fixture(autouse=True)
def set_seed():
    torch.manual_seed(123)


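# Compare the notebook's RoPE implementation against the Hugging Face
# transformers reference (LlamaRotaryEmbedding / apply_rotary_pos_emb),
# using the Llama 2 defaults: context length 4096 and theta base 10,000.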
def test_rope_llama2(notebook):
    # Settings
    batch_size = 1
    context_len = 4096
    num_heads = 4
    head_dim = 16

    # Instantiate RoPE parameters
    cos, sin = notebook.precompute_rope_params(head_dim=head_dim, context_length=context_len)

    # Dummy query and key tensors
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = notebook.compute_rope(queries, cos, sin)
    keys_rot = notebook.compute_rope(keys, cos, sin)

    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=10_000
    )

    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)


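# Same comparison as above, but with a longer context length (8192) and a
# larger RoPE theta base (50,000).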
def test_rope_llama3(notebook):
    # Settings
    batch_size = 1
    context_len = 8192
    num_heads = 4
    head_dim = 16
    theta_base = 50_000

    # Instantiate RoPE parameters
    cos, sin = notebook.precompute_rope_params(
        head_dim=head_dim,
        context_length=context_len,
        theta_base=theta_base
    )

    # Dummy query and key tensors
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = notebook.compute_rope(queries, cos, sin)
    keys_rot = notebook.compute_rope(keys, cos, sin)

    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=theta_base
    )

    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)


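# SiLU(x) = x * sigmoid(x); the notebook's SiLU module should match
# torch.nn.functional.silu (checked with torch.allclose).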
def test_silu(notebook):
    example_batch = torch.randn(2, 3, 4)
    silu = notebook.SiLU()
    assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))


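# torch.nn.RMSNorm was only added in PyTorch 2.4, so this comparison is
# skipped on older versions.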
@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse("2.4"),
    reason="Requires PyTorch 2.4 or newer"
)
def test_rmsnorm(notebook):
    example_batch = torch.randn(2, 3, 4)
    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)

    assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
