mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-18 19:40:30 +00:00
Add llama2 unit tests (#372)

* add llama2 unit tests
* update
* updates
* updates
* update file path
* update requirements file
* rmsnorm test
* update
This commit is contained in:
parent a6d8e93da3
commit b56d0b2942
.github/workflows/basic-tests-linux.yml (vendored, 2 changes)
@@ -35,12 +35,14 @@ jobs:
           python -m pip install --upgrade pip
           pip install pytest nbval
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+          pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt

       - name: Test Selected Python Scripts
         run: |
           pytest setup/02_installing-python-libraries/tests.py
           pytest ch04/01_main-chapter-code/tests.py
           pytest ch05/01_main-chapter-code/tests.py
+          pytest ch05/07_gpt_to_llama/tests/tests.py
           pytest ch06/01_main-chapter-code/tests.py

       - name: Validate Selected Jupyter Notebooks
.github/workflows/basic-tests-macos.yml (vendored, 2 changes)
@@ -35,12 +35,14 @@ jobs:
           python -m pip install --upgrade pip
           pip install pytest nbval
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+          pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt

       - name: Test Selected Python Scripts
         run: |
           pytest setup/02_installing-python-libraries/tests.py
           pytest ch04/01_main-chapter-code/tests.py
           pytest ch05/01_main-chapter-code/tests.py
+          pytest ch05/07_gpt_to_llama/tests/tests.py
           pytest ch06/01_main-chapter-code/tests.py

       - name: Validate Selected Jupyter Notebooks
@@ -39,12 +39,14 @@ jobs:
           pip install pytest nbval
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
           pip install torch==${{ matrix.pytorch-version }}
+          pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt

       - name: Test Selected Python Scripts
         run: |
           pytest setup/02_installing-python-libraries/tests.py
           pytest ch04/01_main-chapter-code/tests.py
           pytest ch05/01_main-chapter-code/tests.py
+          pytest ch05/07_gpt_to_llama/tests/tests.py
           pytest ch06/01_main-chapter-code/tests.py

       - name: Validate Selected Jupyter Notebooks
.github/workflows/basic-tests-windows.yml (vendored, 2 changes)
@@ -38,6 +38,7 @@ jobs:
           pip install pytest nbval
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
           pip install matplotlib==3.9.0
+          pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt

       - name: Test Selected Python Scripts
         shell: bash
@@ -45,6 +46,7 @@ jobs:
           pytest setup/02_installing-python-libraries/tests.py
           pytest ch04/01_main-chapter-code/tests.py
           pytest ch05/01_main-chapter-code/tests.py
+          pytest ch05/07_gpt_to_llama/tests/tests.py
           pytest ch06/01_main-chapter-code/tests.py

       - name: Validate Selected Jupyter Notebooks
ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb

@@ -76,7 +76,7 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
-"outputId": "d0fc89be-74a3-40d0-bc4d-7f6f1febf2cd"
+"outputId": "7ce8fe41-1c24-4f0b-a8d9-352b4af1b46b"
 },
 "outputs": [
 {
@@ -84,7 +84,7 @@
 "output_type": "stream",
 "text": [
 "huggingface_hub version: 0.24.7\n",
-"sentencepiece version: 0.1.99\n",
+"sentencepiece version: 0.2.0\n",
 "torch version: 2.4.1+cu121\n"
 ]
 }
@@ -421,41 +421,39 @@
 " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
 "\n",
 " # Compute the inverse frequencies\n",
-" inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))\n",
+" inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
 "\n",
 " # Generate position indices\n",
 " positions = torch.arange(context_length)\n",
 "\n",
-" # Compute the angles using inverse frequencies and positions\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, emb_dim // 2)\n",
+" # Compute the angles\n",
+" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
 "\n",
-" # Precompute sine and cosine of the angles\n",
-" sin = torch.sin(angles) # Shape: (context_length, emb_dim // 2)\n",
-" cos = torch.cos(angles) # Shape: (context_length, emb_dim // 2)\n",
+" # Expand angles to match the head_dim\n",
+" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
 "\n",
-" return sin, cos\n",
+" # Precompute sine and cosine\n",
+" cos = torch.cos(angles)\n",
+" sin = torch.sin(angles)\n",
+"\n",
+" return cos, sin\n",
 "\n",
-"def compute_rope(x, sin, cos):\n",
+"def compute_rope(x, cos, sin):\n",
 " # x: (batch_size, num_heads, seq_len, head_dim)\n",
 " batch_size, num_heads, seq_len, head_dim = x.shape\n",
 " assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
 "\n",
-" # Split x into even and odd parts\n",
-" x1 = x[..., ::2] # Shape: (batch_size, num_heads, seq_len, head_dim // 2)\n",
-" x2 = x[..., 1::2]\n",
+" # Split x into first half and second half\n",
+" x1 = x[..., : head_dim // 2] # First half\n",
+" x2 = x[..., head_dim // 2 :] # Second half\n",
 "\n",
-" # Ensure sin and cos have correct shapes\n",
-" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)\n",
-" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
+" # Adjust sin and cos shapes\n",
+" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
+" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
 "\n",
 " # Apply the rotary transformation\n",
-" x_rotated_0 = x1 * cos - x2 * sin\n",
-" x_rotated_1 = x1 * sin + x2 * cos\n",
-"\n",
-" # Interleave x_rotated_0 and x_rotated_1\n",
-" x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)\n",
-" x_rotated = x_rotated.flatten(-2)\n",
+" rotated = torch.cat((-x2, x1), dim=-1)\n",
+" x_rotated = (x * cos) + (rotated * sin)\n",
 "\n",
 " return x_rotated.to(dtype=x.dtype)"
 ]
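The rewritten cell swaps the interleaved even/odd rotation for a half-split rotation, the convention used by Hugging Face's Llama implementation, which the new unit tests further down compare against. A minimal sketch of the idea, assuming only PyTorch (the rotate_half helper here is illustrative, not part of the commit):

import torch

def rotate_half(x):
    # Swap the two halves of the last dimension and negate the second half
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# With cos and sin broadcast to (1, 1, seq_len, head_dim), the rewritten
# compute_rope amounts to: x_rotated = (x * cos) + (rotate_half(x) * sin)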
@@ -486,7 +484,7 @@
 "head_dim = 16\n",
 "\n",
 "# Instantiate RoPE parameters\n",
-"sin, cos = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
+"cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
 "\n",
 "# Dummy query and key tensors\n",
 "torch.manual_seed(123)\n",
@@ -494,8 +492,8 @@
 "keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
 "\n",
 "# Apply rotary position embeddings\n",
-"queries_rot = compute_rope(queries, sin, cos)\n",
-"keys_rot = compute_rope(keys, sin, cos)"
+"queries_rot = compute_rope(queries, cos, sin)\n",
+"keys_rot = compute_rope(keys, cos, sin)"
 ]
 },
 {
@@ -554,9 +552,9 @@
 " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
 "\n",
 " ################################### NEW ###################################\n",
-" sin, cos = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
-" self.register_buffer(\"sin\", sin)\n",
+" cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
 " self.register_buffer(\"cos\", cos)\n",
+" self.register_buffer(\"sin\", sin)\n",
 " ###########################################################################\n",
 "\n",
 "\n",
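Registering the precomputed cos and sin tables as buffers keeps them out of the trainable parameters while letting them follow the module across devices and dtypes. A small illustrative sketch (not from the commit):

import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        angles = torch.arange(6.0).reshape(2, 3)
        self.register_buffer("cos", torch.cos(angles))
        self.register_buffer("sin", torch.sin(angles))

m = Demo()
print(list(m.parameters()))  # [] -- buffers are not trainable parameters
print(m.cos.shape)           # torch.Size([2, 3]); buffers move with m.to(device)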
@@ -736,7 +734,7 @@
 "cell_type": "markdown",
 "id": "ba5d991a-559b-47be-96f4-31b881ab2da8",
 "metadata": {
-"id": "aa79780d-74a8-4ee0-934a-9ad63205a02e"
+"id": "ba5d991a-559b-47be-96f4-31b881ab2da8"
 },
 "source": [
 "- As you may recall from [chapter 5](../01_main-chapter-code/ch05.ipynb), the `TransformerBlock` is a repeated block within the main model\n",
@@ -918,7 +916,7 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
-"outputId": "78ab929e-ac78-4b16-ddb1-704d45ee69a8"
+"outputId": "1ca50091-a20c-4a44-b806-9985a5e64135"
 },
 "outputs": [
 {
@@ -954,15 +952,15 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
-"outputId": "c0cbdcc8-dc46-44f7-a800-fbe888a3f9e9"
+"outputId": "b157b5ac-d37c-4b71-f609-45a91f7ed93a"
 },
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"float32 (PyTorch default): 52.27 GB\n",
-"bfloat16: 26.13 GB\n"
+"float32 (PyTorch default): 52.33 GB\n",
+"bfloat16: 26.17 GB\n"
 ]
 }
 ],
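The two estimates differ only by bytes per element, since bfloat16 stores 2 bytes per value versus 4 for float32; a quick check of the ratio (the 52.33 GB figure itself comes from the notebook's memory estimate):

bytes_fp32, bytes_bf16 = 4, 2
print(52.33 * bytes_bf16 / bytes_fp32)  # ~26.17, matching the bfloat16 line above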
@@ -1087,7 +1085,7 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "3357a230-b678-4691-a238-257ee4e80185",
-"outputId": "d326d32c-fa8d-4f2b-84d5-a1b8f35dd387"
+"outputId": "7d4adc4b-53cf-4099-a45f-2fb4fd25edc4"
 },
 "outputs": [
 {
@@ -1131,7 +1129,7 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
-"outputId": "82bc5037-c86c-46c2-b374-269f9d09599a"
+"outputId": "aa18fccc-6533-4446-f57b-546068ad518c"
 },
 "outputs": [
 {
@@ -1213,7 +1211,7 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
-"outputId": "d733bc0a-5136-4c33-d70d-36056f1e8329"
+"outputId": "cbc53f67-a77a-40c9-ed2d-c6f8be066cfb"
 },
 "outputs": [
 {
@@ -1221,7 +1219,7 @@
 "output_type": "stream",
 "text": [
 "Output text:\n",
-" Every effort movesαllRadius deletingpretccappedRadius zas Parte Material Ку términчной herousztusllRadiusotto кра liberotto siguientesagnost#{ (@topicquez restored log\n"
+" Every effort movesαfdmsdn coatELDâte eer tagsיśćinu Lundmysq eer napinu LundANCEHEAD ner}}}رible one}}}رible one puts Dan\n"
 ]
 }
 ],
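The sample above is garbled because the model is still randomly initialized at this point in the notebook; the coherent completions further down are produced only after the pretrained Llama 2 weights are loaded.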
@@ -1322,7 +1320,7 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
-"outputId": "01721809-ace1-4a7a-ab54-8fad2e8f54a6"
+"outputId": "351029ce-b4c0-4d39-8e0e-7e7f44d25647"
 },
 "outputs": [
 {
@@ -1457,7 +1455,7 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "240987e8-a023-462e-9376-9edfb27559ec",
-"outputId": "59830005-42af-406b-c836-38a8f2d7b961"
+"outputId": "3fa7a77a-6203-4d8a-bdaa-afce1f504adf"
 },
 "outputs": [
 {
@@ -1465,7 +1463,7 @@
 "output_type": "stream",
 "text": [
 "Output text:\n",
-" Every effort has been made to ensure that the information contained in this website is accurate and up to date. However, the information is provided without any warranty\n"
+" Every effort has been made to ensure that the information contained in this website is correct and up to date and accurate at the time of publication\n"
 ]
 }
 ],
@@ -1475,7 +1473,7 @@
 "token_ids = generate(\n",
 " model=model,\n",
 " idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
-" max_new_tokens=30,\n",
+" max_new_tokens=25,\n",
 " context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
 " top_k=1,\n",
 " temperature=0.\n",
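With top_k=1 and temperature=0., the book's generate function decodes deterministically: each step reduces to an argmax over the logits. A one-line sketch of that rule (hypothetical logits):

import torch

logits = torch.tensor([2.0, 0.5, 1.0])  # hypothetical next-token logits
next_token = torch.argmax(logits)       # greedy pick, as with top_k=1 / temperature=0.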
@@ -1496,14 +1494,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 35,
+"execution_count": 31,
 "id": "nbvAV7vaz6yc",
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
 },
 "id": "nbvAV7vaz6yc",
-"outputId": "faa930dc-0db2-4095-b395-f97baef08903"
+"outputId": "bd4cae4d-5d5f-4f64-ea37-b979ef2c86bb"
 },
 "outputs": [
 {
@@ -1512,7 +1510,7 @@
 "text": [
 "Output text:\n",
 " What do llamas eat?\n",
-"Llamas are herbivores, which means they eat plants. They eat grass, leaves, and hay.\n"
+"Llamas are herbivores, which means they eat grass, leaves, grasses, and they eat grass\n"
 ]
 }
 ],
ch05/07_gpt_to_llama/tests/test-requirements-extra.txt (new file, 1 line)

@@ -0,0 +1 @@
transformers>=4.44.2
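This single pin exists for the reference check in the new tests: they import Hugging Face's LlamaRotaryEmbedding and apply_rotary_pos_emb (see tests.py below) and compare the notebook's RoPE implementation against them.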
ch05/07_gpt_to_llama/tests/tests.py (new file, 114 lines)

@@ -0,0 +1,114 @@
import io
import os
import sys
import types
import nbformat
import torch
import pytest
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# File for internal use (unit tests)


@pytest.fixture(scope="module")
def notebook():
    def import_definitions_from_notebook(fullname, names):
        # Get the directory of the current test file
        current_dir = os.path.dirname(__file__)
        path = os.path.join(current_dir, "..", fullname + ".ipynb")
        path = os.path.normpath(path)

        # Load the notebook
        if not os.path.exists(path):
            raise FileNotFoundError(f"Notebook file not found at: {path}")

        with io.open(path, "r", encoding="utf-8") as f:
            nb = nbformat.read(f, as_version=4)

        # Create a module to store the imported functions and classes
        mod = types.ModuleType(fullname)
        sys.modules[fullname] = mod

        # Go through the notebook cells and only execute function or class definitions
        for cell in nb.cells:
            if cell.cell_type == "code":
                cell_code = cell.source
                for name in names:
                    # Check for function or class definitions
                    if f"def {name}" in cell_code or f"class {name}" in cell_code:
                        exec(cell_code, mod.__dict__)
        return mod

    # Specify the notebook name and functions/classes to import
    fullname = "converting-gpt-to-llama2"
    names = ["precompute_rope_params", "compute_rope", "SiLU", "RMSNorm"]

    # Import the required functions and classes from the notebook
    return import_definitions_from_notebook(fullname, names)


@pytest.fixture(autouse=True)
def set_seed():
    torch.manual_seed(123)


def test_rope(notebook):
    # Settings
    batch_size = 1
    context_len = 5
    num_heads = 4
    head_dim = 16

    # Instantiate RoPE parameters
    cos, sin = notebook.precompute_rope_params(head_dim=head_dim, context_length=context_len)

    # Dummy query and key tensors
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = notebook.compute_rope(queries, cos, sin)
    keys_rot = notebook.compute_rope(keys, cos, sin)

    class RoPEConfig:
        rope_type = "default"
        rope_scaling = None
        factor = 1.0
        dim: int = head_dim
        rope_theta = 10000
        max_position_embeddings: int = 4096
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    config = RoPEConfig()

    rot_emb = LlamaRotaryEmbedding(config=config)
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)


def test_silu(notebook):
    example_batch = torch.randn(2, 3, 4)
    silu = notebook.SiLU()
    assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))


@pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
def test_rmsnorm(notebook):
    example_batch = torch.randn(2, 3, 4)
    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1])
    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)

    assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
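For reference, both implementations compared in test_rmsnorm should compute the same normalization; a sketch of that computation, with eps=1e-6 assumed to match the torch.nn.RMSNorm instantiation above:

import torch

def rms_norm_ref(x, weight, eps=1e-6):
    # Scale by the reciprocal root-mean-square over the last dimension
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x / rms) * weight

Locally, the suite runs with the same command the workflows use: pytest ch05/07_gpt_to_llama/tests/tests.py.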