Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-11-04 03:40:21 +00:00)
RoPE updates (#412)

* RoPE updates
* Apply suggestions from code review
* updates
* updates
* updates

This commit is contained in:
    parent 4f9c9fb703
    commit 7cd6a670ed
@@ -426,7 +426,7 @@
    "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
    "\n",
    "    # Compute the inverse frequencies\n",
-    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
+    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
    "\n",
    "    # Generate position indices\n",
    "    positions = torch.arange(context_length)\n",

@@ -493,8 +493,8 @@
    "\n",
    "# Dummy query and key tensors\n",
    "torch.manual_seed(123)\n",
-    "queries = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
-    "keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
+    "queries = torch.randn(batch_size, num_heads, context_len, head_dim)\n",
+    "keys = torch.randn(batch_size, num_heads, context_len, head_dim)\n",
    "\n",
    "# Apply rotary position embeddings\n",
    "queries_rot = compute_rope(queries, cos, sin)\n",

@@ -1691,7 +1691,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.6"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {

@@ -278,7 +278,7 @@
    "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
    "\n",
    "    # Compute the inverse frequencies\n",
-    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
+    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
    "\n",
    "    ################################ NEW ###############################################\n",
    "    # Frequency adjustments\n",

@@ -383,8 +383,8 @@
    "\n",
    "# Dummy query and key tensors\n",
    "torch.manual_seed(123)\n",
-    "queries = torch.randn(batch_size, llama_3_context_len, num_heads, head_dim)\n",
-    "keys = torch.randn(batch_size, llama_3_context_len, num_heads, head_dim)\n",
+    "queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n",
+    "keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n",
    "\n",
    "# Apply rotary position embeddings\n",
    "queries_rot = compute_rope(queries, cos, sin)\n",

@@ -2701,7 +2701,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.6"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {

@@ -133,7 +133,7 @@
    "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
    "\n",
    "    # Compute the inverse frequencies\n",
-    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
+    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
    "\n",
    "    # Frequency adjustments\n",
    "    if freq_config is not None:\n",

@@ -1061,7 +1061,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.6"
  }
 },
 "nbformat": 4,

ch05/07_gpt_to_llama/tests/Untitled.ipynb (new file, 74 lines)
@@ -0,0 +1,74 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "40d2405d-ee10-44ad-b20e-cf32078f926a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True | head dim: 1, tensor([]), tensor([])\n",
      "True | head dim: 2, tensor([1.]), tensor([1.])\n",
      "True | head dim: 3, tensor([1.]), tensor([1.])\n",
      "True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n",
      "False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n",
      "True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n",
      "False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n",
      "True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n",
      "False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n",
      "True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n",
      "False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "theta_base = 10_000\n",
    "\n",
    "for head_dim in range(1, 12):\n",
    "\n",
    "    before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
    "    after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
    "    \n",
    "    s = f\"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}\"\n",
    "    print(s)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0abfbf38-93a4-4994-8e7e-a543477268a8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@@ -1 +1,2 @@
-transformers>=4.44.2
+transformers>=4.44.2
+litgpt>=0.5.0

@@ -10,11 +10,82 @@ import os
import sys
import types
import nbformat
from typing import Optional, Tuple
import torch
import pytest
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
def litgpt_build_rope_cache(
    seq_len: int,
    n_elem: int,
    device: Optional[torch.device] = None,
    base: int = 10000,
    condense_ratio: int = 1,
    extra_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Enhanced Transformer with Rotary Position Embedding.

    Args:
        seq_len (int): Sequence length.
        n_elem (int): Number of elements (head dimension).
        device (torch.device, optional): Device for tensor allocations.
        base (int, optional): Base for computing inverse frequencies.
        condense_ratio (int, optional): Ratio to condense the position indices.
        extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
    """

    # Compute the inverse frequencies theta
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

    if extra_config is not None:
        orig_context_len = extra_config["original_max_seq_len"]
        factor = extra_config["factor"]
        low_freq_factor = extra_config["low_freq_factor"]
        high_freq_factor = extra_config["high_freq_factor"]

        wavelen = 2 * torch.pi / theta
        ratio = orig_context_len / wavelen
        smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)

        # Compute adjusted_theta without masked indexing
        adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta
        theta = adjusted_theta

    # Create position indices `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)

    return torch.cos(idx_theta), torch.sin(idx_theta)
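
Note: a hedged illustration of the extra_config branch above (toy values, not part of the test file). Long-wavelength (low-frequency) components are pulled toward theta / factor, short-wavelength components are left unchanged, and smooth_factor interpolates between the two regimes. The factor, low_freq_factor, high_freq_factor, and original_max_seq_len values below match the litgpt_rope_config used in test_rope_llama3_12 further down; base and n_elem are arbitrary example values.

import torch

base, n_elem = 10_000, 16
factor, low_freq_factor, high_freq_factor, orig_context_len = 8.0, 1.0, 4.0, 8192

theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))

wavelen = 2 * torch.pi / theta
ratio = orig_context_len / wavelen
smooth_factor = torch.clamp((ratio - low_freq_factor) / (high_freq_factor - low_freq_factor), min=0.0, max=1.0)
adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta

print(theta)           # unadjusted inverse frequencies
print(adjusted_theta)  # low-frequency entries approach theta / factor, high-frequency entries stay unchanged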


# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
def litgpt_apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    head_size = x.size(-1)
    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
    x2 = x[..., head_size // 2:]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    if cos.dim() > 1:
        # batch dimensions must align
        # sin/cos are (B, T, hs) so we unsqeeze -3 for nh
        # we count from back because all of apply_rope does
        cos = cos.unsqueeze(-3)
        sin = sin.unsqueeze(-3)

    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)
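
Note: for orientation, a minimal usage sketch of the two LitGPT helpers above on dummy shapes (the sizes are arbitrary example values, not taken from the tests; head_dim is even):

import torch

batch_size, num_heads, context_len, head_dim = 2, 4, 5, 16

cos, sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=10_000)  # each (context_len, head_dim)
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
queries_rot = litgpt_apply_rope(queries, cos, sin)
print(cos.shape, queries_rot.shape)  # torch.Size([5, 16]) torch.Size([2, 4, 5, 16])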


@pytest.fixture(scope="module")
def notebook():
    def import_definitions_from_notebook(notebooks):

@@ -84,21 +155,30 @@ def test_rope_llama2(notebook):
    queries_rot = this_nb.compute_rope(queries, cos, sin)
    keys_rot = this_nb.compute_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=10_000
    )

    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)

    # Generate reference RoPE via LitGPT
    litgpt_cos, litgpt_sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=10_000)
    litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
    litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)

    torch.testing.assert_close(sin, litgpt_sin)
    torch.testing.assert_close(cos, litgpt_cos)
    torch.testing.assert_close(keys_rot, litgpt_keys_rot)
    torch.testing.assert_close(queries_rot, litgpt_queries_rot)


def test_rope_llama3(notebook):

@@ -128,6 +208,7 @@ def test_rope_llama3(notebook):
    queries_rot = nb1.compute_rope(queries, cos, sin)
    keys_rot = nb1.compute_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,

@@ -143,6 +224,16 @@ def test_rope_llama3(notebook):
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)

    # Generate reference RoPE via LitGPT
    litgpt_cos, litgpt_sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=theta_base)
    litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
    litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)

    torch.testing.assert_close(sin, litgpt_sin)
    torch.testing.assert_close(cos, litgpt_cos)
    torch.testing.assert_close(keys_rot, litgpt_keys_rot)
    torch.testing.assert_close(queries_rot, litgpt_queries_rot)


def test_rope_llama3_12(notebook):

@@ -180,6 +271,7 @@ def test_rope_llama3_12(notebook):
    queries_rot = nb1.compute_rope(queries, cos, sin)
    keys_rot = nb1.compute_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    hf_rope_params = {
        "factor": 8.0,
        "low_freq_factor": 1.0,

@@ -210,6 +302,28 @@ def test_rope_llama3_12(notebook):
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)

    # Generate reference RoPE via LitGPT
    litgpt_rope_config = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_seq_len": 8192
    }

    litgpt_cos, litgpt_sin = litgpt_build_rope_cache(
        context_len,
        n_elem=head_dim,
        base=rope_theta,
        extra_config=litgpt_rope_config
    )
    litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
    litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)

    torch.testing.assert_close(sin, litgpt_sin)
    torch.testing.assert_close(cos, litgpt_cos)
    torch.testing.assert_close(keys_rot, litgpt_keys_rot)
    torch.testing.assert_close(queries_rot, litgpt_queries_rot)


def test_silu(notebook):
    example_batch = torch.randn(2, 3, 4)