Add llama2 unit tests (#372)

* add llama2 unit tests

* update

* updates

* updates

* update file path

* update requirements file

* rmsnorm test

* update
Authored by Sebastian Raschka on 2024-09-25 19:40:36 -05:00, committed by GitHub
parent a6d8e93da3
commit b56d0b2942
7 changed files with 164 additions and 43 deletions

View File

@@ -35,12 +35,14 @@ jobs:
         python -m pip install --upgrade pip
         pip install pytest nbval
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+        pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
     - name: Test Selected Python Scripts
       run: |
         pytest setup/02_installing-python-libraries/tests.py
         pytest ch04/01_main-chapter-code/tests.py
         pytest ch05/01_main-chapter-code/tests.py
+        pytest ch05/07_gpt_to_llama/tests/tests.py
         pytest ch06/01_main-chapter-code/tests.py
     - name: Validate Selected Jupyter Notebooks

View File

@@ -35,12 +35,14 @@ jobs:
         python -m pip install --upgrade pip
         pip install pytest nbval
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+        pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
     - name: Test Selected Python Scripts
       run: |
         pytest setup/02_installing-python-libraries/tests.py
         pytest ch04/01_main-chapter-code/tests.py
         pytest ch05/01_main-chapter-code/tests.py
+        pytest ch05/07_gpt_to_llama/tests/tests.py
         pytest ch06/01_main-chapter-code/tests.py
     - name: Validate Selected Jupyter Notebooks

View File

@@ -39,12 +39,14 @@ jobs:
         pip install pytest nbval
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
         pip install torch==${{ matrix.pytorch-version }}
+        pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
     - name: Test Selected Python Scripts
       run: |
         pytest setup/02_installing-python-libraries/tests.py
         pytest ch04/01_main-chapter-code/tests.py
         pytest ch05/01_main-chapter-code/tests.py
+        pytest ch05/07_gpt_to_llama/tests/tests.py
         pytest ch06/01_main-chapter-code/tests.py
     - name: Validate Selected Jupyter Notebooks

View File

@@ -38,6 +38,7 @@ jobs:
         pip install pytest nbval
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
         pip install matplotlib==3.9.0
+        pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
     - name: Test Selected Python Scripts
       shell: bash
@@ -45,6 +46,7 @@ jobs:
         pytest setup/02_installing-python-libraries/tests.py
         pytest ch04/01_main-chapter-code/tests.py
         pytest ch05/01_main-chapter-code/tests.py
+        pytest ch05/07_gpt_to_llama/tests/tests.py
         pytest ch06/01_main-chapter-code/tests.py
     - name: Validate Selected Jupyter Notebooks

ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb View File

@@ -76,7 +76,7 @@
  "base_uri": "https://localhost:8080/"
  },
  "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
- "outputId": "d0fc89be-74a3-40d0-bc4d-7f6f1febf2cd"
+ "outputId": "7ce8fe41-1c24-4f0b-a8d9-352b4af1b46b"
  },
  "outputs": [
  {
@@ -84,7 +84,7 @@
  "output_type": "stream",
  "text": [
  "huggingface_hub version: 0.24.7\n",
- "sentencepiece version: 0.1.99\n",
+ "sentencepiece version: 0.2.0\n",
  "torch version: 2.4.1+cu121\n"
  ]
  }
@@ -421,41 +421,39 @@
  " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
  "\n",
  " # Compute the inverse frequencies\n",
- " inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))\n",
+ " inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
  "\n",
  " # Generate position indices\n",
  " positions = torch.arange(context_length)\n",
  "\n",
- " # Compute the angles using inverse frequencies and positions\n",
- " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, emb_dim // 2)\n",
+ " # Compute the angles\n",
+ " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
  "\n",
- " # Precompute sine and cosine of the angles\n",
- " sin = torch.sin(angles) # Shape: (context_length, emb_dim // 2)\n",
- " cos = torch.cos(angles) # Shape: (context_length, emb_dim // 2)\n",
+ " # Expand angles to match the head_dim\n",
+ " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
  "\n",
- " return sin, cos\n",
+ " # Precompute sine and cosine\n",
+ " cos = torch.cos(angles)\n",
+ " sin = torch.sin(angles)\n",
+ "\n",
+ " return cos, sin\n",
  "\n",
- "def compute_rope(x, sin, cos):\n",
+ "def compute_rope(x, cos, sin):\n",
  " # x: (batch_size, num_heads, seq_len, head_dim)\n",
  " batch_size, num_heads, seq_len, head_dim = x.shape\n",
  " assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
  "\n",
- " # Split x into even and odd parts\n",
- " x1 = x[..., ::2] # Shape: (batch_size, num_heads, seq_len, head_dim // 2)\n",
- " x2 = x[..., 1::2]\n",
+ " # Split x into first half and second half\n",
+ " x1 = x[..., : head_dim // 2] # First half\n",
+ " x2 = x[..., head_dim // 2 :] # Second half\n",
  "\n",
- " # Ensure sin and cos have correct shapes\n",
- " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)\n",
- " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
+ " # Adjust sin and cos shapes\n",
+ " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
+ " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
  "\n",
  " # Apply the rotary transformation\n",
- " x_rotated_0 = x1 * cos - x2 * sin\n",
- " x_rotated_1 = x1 * sin + x2 * cos\n",
- "\n",
- " # Interleave x_rotated_0 and x_rotated_1\n",
- " x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)\n",
- " x_rotated = x_rotated.flatten(-2)\n",
+ " rotated = torch.cat((-x2, x1), dim=-1)\n",
+ " x_rotated = (x * cos) + (rotated * sin)\n",
  "\n",
  " return x_rotated.to(dtype=x.dtype)"
  ]
@@ -486,7 +484,7 @@
  "head_dim = 16\n",
  "\n",
  "# Instantiate RoPE parameters\n",
- "sin, cos = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
+ "cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
  "\n",
  "# Dummy query and key tensors\n",
  "torch.manual_seed(123)\n",
@@ -494,8 +492,8 @@
  "keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
  "\n",
  "# Apply rotary position embeddings\n",
- "queries_rot = compute_rope(queries, sin, cos)\n",
- "keys_rot = compute_rope(keys, sin, cos)"
+ "queries_rot = compute_rope(queries, cos, sin)\n",
+ "keys_rot = compute_rope(keys, cos, sin)"
  ]
  },
  {
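Side note on the two hunks above: the notebook switches from the interleaved even/odd RoPE formulation to the half-split formulation used by Hugging Face, which is what lets the new unit test further down compare against `apply_rotary_pos_emb` directly. The changed cells assemble into the following self-contained sketch (plain PyTorch; the final norm check is an extra sanity assertion, not part of the notebook):

```python
import torch

def precompute_rope_params(head_dim, context_length, theta_base=10_000):
    # Inverse frequencies cover the first half of the head dimension
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))
    positions = torch.arange(context_length)
    angles = positions[:, None] * inv_freq[None, :]   # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=1)       # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)

def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    _, _, seq_len, head_dim = x.shape
    x1 = x[..., : head_dim // 2]                      # first half
    x2 = x[..., head_dim // 2:]                       # second half
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
    rotated = torch.cat((-x2, x1), dim=-1)
    return (x * cos + rotated * sin).to(dtype=x.dtype)

cos, sin = precompute_rope_params(head_dim=16, context_length=5)
queries = torch.randn(1, 4, 5, 16)  # (batch, heads, seq, head_dim)
queries_rot = compute_rope(queries, cos, sin)
# RoPE is a pure rotation, so per-token vector norms are preserved
assert torch.allclose(queries.norm(dim=-1), queries_rot.norm(dim=-1), atol=1e-5)
```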
@@ -554,9 +552,9 @@
  " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
  "\n",
  " ################################### NEW ###################################\n",
- " sin, cos = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
- " self.register_buffer(\"sin\", sin)\n",
+ " cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
  " self.register_buffer(\"cos\", cos)\n",
+ " self.register_buffer(\"sin\", sin)\n",
  " ###########################################################################\n",
  "\n",
  "\n",
@@ -736,7 +734,7 @@
  "cell_type": "markdown",
  "id": "ba5d991a-559b-47be-96f4-31b881ab2da8",
  "metadata": {
- "id": "aa79780d-74a8-4ee0-934a-9ad63205a02e"
+ "id": "ba5d991a-559b-47be-96f4-31b881ab2da8"
  },
  "source": [
  "- As you may recall from [chapter 5](../01_main-chapter-code/ch05.ipynb), the `TransformerBlock` is a repeated block within the main model\n",
@@ -918,7 +916,7 @@
  "base_uri": "https://localhost:8080/"
  },
  "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
- "outputId": "78ab929e-ac78-4b16-ddb1-704d45ee69a8"
+ "outputId": "1ca50091-a20c-4a44-b806-9985a5e64135"
  },
  "outputs": [
  {
@@ -954,15 +952,15 @@
  "base_uri": "https://localhost:8080/"
  },
  "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
- "outputId": "c0cbdcc8-dc46-44f7-a800-fbe888a3f9e9"
+ "outputId": "b157b5ac-d37c-4b71-f609-45a91f7ed93a"
  },
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "float32 (PyTorch default): 52.27 GB\n",
- "bfloat16: 26.13 GB\n"
+ "float32 (PyTorch default): 52.33 GB\n",
+ "bfloat16: 26.17 GB\n"
  ]
  }
  ],
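The small bump from 52.27 to 52.33 GB is consistent with the RoPE cos/sin buffers doubling in width (context_length x head_dim instead of context_length x head_dim/2 per layer); tensors registered via `register_buffer` count toward the model's footprint even though they are not parameters. As a rough cross-check, a back-of-the-envelope sketch (the ~6.74B parameter count and the per-layer 4096 x 4096 mask buffers are assumptions about the 7B configuration, not values taken from this diff):

```python
n_params = 6.74e9                    # assumed parameter count for Llama 2 7B
n_buffers = 32 * 4096 * 4096         # assumed: one context_length^2 mask buffer per layer
elements = 2 * n_params + n_buffers  # parameters, plus one gradient per parameter, plus buffers
for name, nbytes in (("float32 (PyTorch default)", 4), ("bfloat16", 2)):
    print(f"{name}: {elements * nbytes / 1024**3:.2f} GB")  # ~52 GB and ~26 GB
```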
@@ -1087,7 +1085,7 @@
  "base_uri": "https://localhost:8080/"
  },
  "id": "3357a230-b678-4691-a238-257ee4e80185",
- "outputId": "d326d32c-fa8d-4f2b-84d5-a1b8f35dd387"
+ "outputId": "7d4adc4b-53cf-4099-a45f-2fb4fd25edc4"
  },
  "outputs": [
  {
@@ -1131,7 +1129,7 @@
  "base_uri": "https://localhost:8080/"
  },
  "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
- "outputId": "82bc5037-c86c-46c2-b374-269f9d09599a"
+ "outputId": "aa18fccc-6533-4446-f57b-546068ad518c"
  },
  "outputs": [
  {
@@ -1213,7 +1211,7 @@
  "base_uri": "https://localhost:8080/"
  },
  "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
- "outputId": "d733bc0a-5136-4c33-d70d-36056f1e8329"
+ "outputId": "cbc53f67-a77a-40c9-ed2d-c6f8be066cfb"
  },
  "outputs": [
  {
@@ -1221,7 +1219,7 @@
  "output_type": "stream",
  "text": [
  "Output text:\n",
- " Every effort movesαllRadius deletingpretccappedRadius zas Parte Material Ку términчной herousztusllRadiusotto кра liberotto siguientesagnost#{ (@topicquez restored log\n"
+ " Every effort movesαfdmsdn coatELDâte eer tagsיśćinu Lundmysq eer napinu LundANCEHEAD ner}}}رible one}}}رible one puts Dan\n"
  ]
  }
  ],
@@ -1322,7 +1320,7 @@
  "base_uri": "https://localhost:8080/"
  },
  "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
- "outputId": "01721809-ace1-4a7a-ab54-8fad2e8f54a6"
+ "outputId": "351029ce-b4c0-4d39-8e0e-7e7f44d25647"
  },
  "outputs": [
  {
@@ -1457,7 +1455,7 @@
  "base_uri": "https://localhost:8080/"
  },
  "id": "240987e8-a023-462e-9376-9edfb27559ec",
- "outputId": "59830005-42af-406b-c836-38a8f2d7b961"
+ "outputId": "3fa7a77a-6203-4d8a-bdaa-afce1f504adf"
  },
  "outputs": [
  {
@@ -1465,7 +1463,7 @@
  "output_type": "stream",
  "text": [
  "Output text:\n",
- " Every effort has been made to ensure that the information contained in this website is accurate and up to date. However, the information is provided without any warranty\n"
+ " Every effort has been made to ensure that the information contained in this website is correct and up to date and accurate at the time of publication\n"
  ]
  }
  ],
@@ -1475,7 +1473,7 @@
  "token_ids = generate(\n",
  " model=model,\n",
  " idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
- " max_new_tokens=30,\n",
+ " max_new_tokens=25,\n",
  " context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
  " top_k=1,\n",
  " temperature=0.\n",
@@ -1496,14 +1494,14 @@
  },
  {
  "cell_type": "code",
- "execution_count": 35,
+ "execution_count": 31,
  "id": "nbvAV7vaz6yc",
  "metadata": {
  "colab": {
  "base_uri": "https://localhost:8080/"
  },
  "id": "nbvAV7vaz6yc",
- "outputId": "faa930dc-0db2-4095-b395-f97baef08903"
+ "outputId": "bd4cae4d-5d5f-4f64-ea37-b979ef2c86bb"
  },
  "outputs": [
  {
@@ -1512,7 +1510,7 @@
  "text": [
  "Output text:\n",
  " What do llamas eat?\n",
- "Llamas are herbivores, which means they eat plants. They eat grass, leaves, and hay.\n"
+ "Llamas are herbivores, which means they eat grass, leaves, grasses, and they eat grass\n"
  ]
  }
  ],
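A note on the generation hunks above: with `top_k=1` and `temperature=0.`, `generate` reduces to greedy decoding, i.e. it always appends the single most likely next token. A minimal sketch of such a loop (hypothetical `greedy_generate` helper, assuming a model that maps token IDs to logits of shape `(batch, seq_len, vocab_size)`):

```python
import torch

@torch.no_grad()
def greedy_generate(model, idx, max_new_tokens, context_size):
    # Greedy decoding: always append the argmax token
    for _ in range(max_new_tokens):
        logits = model(idx[:, -context_size:])  # (batch, seq, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_id], dim=1)
    return idx
```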

ch05/07_gpt_to_llama/tests/test-requirements-extra.txt View File

@@ -0,0 +1 @@
transformers>=4.44.2
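The new requirements file pins `transformers>=4.44.2` for the Hugging Face reference implementations used in the tests below. A quick local sanity check of the pin (a sketch; `packaging` is assumed available, which it is as a `transformers` dependency):

```python
from importlib.metadata import version
from packaging.version import Version

# Fail early if the installed transformers predates the pin
assert Version(version("transformers")) >= Version("4.44.2"), "transformers too old for these tests"
```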

ch05/07_gpt_to_llama/tests/tests.py View File

@@ -0,0 +1,114 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# File for internal use (unit tests)

import io
import os
import sys
import types

import nbformat
import torch
import pytest
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


@pytest.fixture(scope="module")
def notebook():
    def import_definitions_from_notebook(fullname, names):
        # Get the directory of the current test file
        current_dir = os.path.dirname(__file__)
        path = os.path.join(current_dir, "..", fullname + ".ipynb")
        path = os.path.normpath(path)

        # Load the notebook
        if not os.path.exists(path):
            raise FileNotFoundError(f"Notebook file not found at: {path}")

        with io.open(path, "r", encoding="utf-8") as f:
            nb = nbformat.read(f, as_version=4)

        # Create a module to store the imported functions and classes
        mod = types.ModuleType(fullname)
        sys.modules[fullname] = mod

        # Go through the notebook cells and only execute function or class definitions
        for cell in nb.cells:
            if cell.cell_type == "code":
                cell_code = cell.source
                for name in names:
                    # Check for function or class definitions
                    if f"def {name}" in cell_code or f"class {name}" in cell_code:
                        exec(cell_code, mod.__dict__)
        return mod

    # Specify the notebook name and functions/classes to import
    fullname = "converting-gpt-to-llama2"
    names = ["precompute_rope_params", "compute_rope", "SiLU", "RMSNorm"]

    # Import the required functions and classes from the notebook
    return import_definitions_from_notebook(fullname, names)


@pytest.fixture(autouse=True)
def set_seed():
    torch.manual_seed(123)


def test_rope(notebook):
    # Settings
    batch_size = 1
    context_len = 5
    num_heads = 4
    head_dim = 16

    # Instantiate RoPE parameters
    cos, sin = notebook.precompute_rope_params(head_dim=head_dim, context_length=context_len)

    # Dummy query and key tensors
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)

    # Apply rotary position embeddings
    queries_rot = notebook.compute_rope(queries, cos, sin)
    keys_rot = notebook.compute_rope(keys, cos, sin)

    class RoPEConfig:
        rope_type = "default"
        rope_scaling = None
        factor = 1.0
        dim: int = head_dim
        rope_theta = 10000
        max_position_embeddings: int = 4096
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    config = RoPEConfig()

    rot_emb = LlamaRotaryEmbedding(config=config)
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)


def test_silu(notebook):
    example_batch = torch.randn(2, 3, 4)
    silu = notebook.SiLU()
    assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))


@pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
def test_rmsnorm(notebook):
    example_batch = torch.randn(2, 3, 4)
    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1])
    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)

    assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
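For readers verifying the last test by hand: RMSNorm scales each feature vector by the reciprocal of its root mean square, y = x / sqrt(mean(x^2) + eps) * weight. A standalone sketch against the built-in module (requires PyTorch 2.4+, matching the test's skip condition; `rmsnorm_reference` is an illustrative name, not repo code):

```python
import torch

def rmsnorm_reference(x, weight, eps=1e-6):
    # Normalize by the root mean square over the last dimension, then apply the learned scale
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

x = torch.randn(2, 3, 4)
builtin = torch.nn.RMSNorm(4, eps=1e-6)  # affine weight initialized to ones
assert torch.allclose(rmsnorm_reference(x, torch.ones(4)), builtin(x), atol=1e-6)
```

With the extra requirement installed, the whole suite runs via `pytest ch05/07_gpt_to_llama/tests/tests.py` from the repository root, mirroring the CI steps added above.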