diff --git a/ch05/07_gpt_to_llama/tests/Untitled.ipynb b/ch05/07_gpt_to_llama/tests/Untitled.ipynb deleted file mode 100644 index 1375a9e..0000000 --- a/ch05/07_gpt_to_llama/tests/Untitled.ipynb +++ /dev/null @@ -1,74 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 9, - "id": "40d2405d-ee10-44ad-b20e-cf32078f926a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True | head dim: 1, tensor([]), tensor([])\n", - "True | head dim: 2, tensor([1.]), tensor([1.])\n", - "True | head dim: 3, tensor([1.]), tensor([1.])\n", - "True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n", - "False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n", - "True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n", - "False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n", - "True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n", - "False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n", - "True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n", - "False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n" - ] - } - ], - "source": [ - "import torch\n", - "\n", - "theta_base = 10_000\n", - "\n", - "for head_dim in range(1, 12):\n", - "\n", - " before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", - " after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n", - " \n", - " s = f\"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}\"\n", - " print(s)\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0abfbf38-93a4-4994-8e7e-a543477268a8", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/ch05/07_gpt_to_llama/tests/tests.py b/ch05/07_gpt_to_llama/tests/tests.py index 395f9ec..22e00e9 100644 --- a/ch05/07_gpt_to_llama/tests/tests.py +++ b/ch05/07_gpt_to_llama/tests/tests.py @@ -10,14 +10,20 @@ import os import sys import types import nbformat +from packaging import version from typing import Optional, Tuple import torch import pytest +import transformers from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb -# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py +transformers_version = transformers.__version__ + +# LitGPT code function `litgpt_build_rope_cache` from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py # LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE + + def litgpt_build_rope_cache( seq_len: int, n_elem: int, @@ -143,6 +149,7 @@ def test_rope_llama2(notebook): context_len = 4096 num_heads = 4 head_dim = 16 + theta_base = 10_000 
# Instantiate RoPE parameters cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len) @@ -156,11 +163,24 @@ def test_rope_llama2(notebook): keys_rot = this_nb.compute_rope(keys, cos, sin) # Generate reference RoPE via HF - rot_emb = LlamaRotaryEmbedding( - dim=head_dim, - max_position_embeddings=context_len, - base=10_000 - ) + + if version.parse(transformers_version) < version.parse("4.48"): + rot_emb = LlamaRotaryEmbedding( + dim=head_dim, + max_position_embeddings=context_len, + base=theta_base + ) + else: + class RoPEConfig: + dim: int = head_dim + rope_theta = theta_base + max_position_embeddings: int = 8192 + hidden_size = head_dim * num_heads + num_attention_heads = num_heads + + config = RoPEConfig() + rot_emb = LlamaRotaryEmbedding(config=config) + position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0) ref_cos, ref_sin = rot_emb(queries, position_ids) ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin) @@ -209,11 +229,22 @@ def test_rope_llama3(notebook): keys_rot = nb1.compute_rope(keys, cos, sin) # Generate reference RoPE via HF - rot_emb = LlamaRotaryEmbedding( - dim=head_dim, - max_position_embeddings=context_len, - base=theta_base - ) + if version.parse(transformers_version) < version.parse("4.48"): + rot_emb = LlamaRotaryEmbedding( + dim=head_dim, + max_position_embeddings=context_len, + base=theta_base + ) + else: + class RoPEConfig: + dim: int = head_dim + rope_theta = theta_base + max_position_embeddings: int = 8192 + hidden_size = head_dim * num_heads + num_attention_heads = num_heads + + config = RoPEConfig() + rot_emb = LlamaRotaryEmbedding(config=config) position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0) ref_cos, ref_sin = rot_emb(queries, position_ids) diff --git a/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb b/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb index c2c5d9e..64ca0ac 100644 --- a/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb +++ b/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb @@ -230,13 +230,34 @@ ], "source": [ "import json\n", + "import os\n", + "import urllib\n", + "\n", + "\n", + "def download_and_load_file(file_path, url):\n", + "\n", + " if not os.path.exists(file_path):\n", + " with urllib.request.urlopen(url) as response:\n", + " text_data = response.read().decode(\"utf-8\")\n", + " with open(file_path, \"w\", encoding=\"utf-8\") as file:\n", + " file.write(text_data)\n", + " else:\n", + " with open(file_path, \"r\", encoding=\"utf-8\") as file:\n", + " text_data = file.read()\n", + "\n", + " with open(file_path, \"r\", encoding=\"utf-8\") as file:\n", + " data = json.load(file)\n", + "\n", + " return data\n", "\n", "\n", "file_path = \"instruction-data-with-preference.json\"\n", + "url = (\n", + " \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch\"\n", + " \"/main/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json\"\n", + ")\n", "\n", - "with open(file_path, \"r\", encoding=\"utf-8\") as file:\n", - " data = json.load(file)\n", - "\n", + "data = download_and_load_file(file_path, url)\n", "print(\"Number of entries:\", len(data))" ] }, @@ -1546,7 +1567,6 @@ }, "outputs": [], "source": [ - "import os\n", "from pathlib import Path\n", "import shutil\n", "\n",
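For context on the removed `Untitled.ipynb` scratch notebook above: it compared two formulations of the RoPE inverse frequencies. Below is a minimal standalone sketch of that check (PyTorch only, same values as the notebook), with the reasoning for why the two variants coincide exactly when `head_dim` is even spelled out in comments.

```python
# Sketch of the check from the deleted Untitled.ipynb: two ways to compute
# RoPE inverse frequencies. For even head_dim = 2k, arange(0, head_dim, 2) / head_dim
# equals [0, 2, ..., 2k-2] / 2k = arange(0, k) / k, so both variants agree exactly;
# for odd head_dim they differ, which is why the notebook printed False there.
import torch

theta_base = 10_000

for head_dim in range(1, 12):
    # variant 1: index the half-dimension directly and divide by head_dim // 2
    before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))
    # variant 2: take the even indices and divide by the full head_dim (the form used in HF's Llama code)
    after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
    print(f"{torch.equal(before, after)} | head dim: {head_dim}")
```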
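The `tests.py` changes gate the Hugging Face reference-RoPE construction on the installed transformers version, since `LlamaRotaryEmbedding` switched from `dim`/`max_position_embeddings`/`base` keyword arguments to a config-object interface around release 4.48. A standalone sketch of that pattern follows; the `RoPEConfig` attribute names and the `rot_emb(x, position_ids)` call are taken from the diff above and may need adjusting for other transformers releases.

```python
# Sketch of the version-gated reference-RoPE construction added in tests.py.
# Assumes transformers and packaging are installed; the RoPEConfig attributes
# mirror the model-config fields the newer LlamaRotaryEmbedding reads.
import torch
import transformers
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, num_heads, context_len, theta_base = 16, 4, 4096, 10_000

if version.parse(transformers.__version__) < version.parse("4.48"):
    # older API: the rotary embedding is parameterized directly
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=theta_base,
    )
else:
    # newer API: parameters are read from a config object
    class RoPEConfig:
        rope_theta = theta_base
        max_position_embeddings = context_len
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    rot_emb = LlamaRotaryEmbedding(config=RoPEConfig())

# The forward pass returns the cos/sin tables for the given positions
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
dummy = torch.zeros(1, num_heads, context_len, head_dim)
cos, sin = rot_emb(dummy, position_ids)
print(cos.shape, sin.shape)  # expect (1, context_len, head_dim) for both
```

Factoring this branch into a small shared helper would also let `test_rope_llama2` and `test_rope_llama3` reuse one construction path instead of duplicating the version check.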