Auto download DPO dataset if not already available in path (#479)

* Auto download DPO dataset if not already available in path

* update unit tests to account for the latest HF transformers release

* PEP 8 fixes
Sebastian Raschka 2025-01-12 12:27:28 -06:00 committed by GitHub
parent 05f2a398b8
commit 992f3068d1
3 changed files with 66 additions and 89 deletions


@@ -1,74 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"id": "40d2405d-ee10-44ad-b20e-cf32078f926a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True | head dim: 1, tensor([]), tensor([])\n",
"True | head dim: 2, tensor([1.]), tensor([1.])\n",
"True | head dim: 3, tensor([1.]), tensor([1.])\n",
"True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n",
"False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n",
"True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n",
"False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n",
"True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n",
"False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n",
"True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n",
"False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n"
]
}
],
"source": [
"import torch\n",
"\n",
"theta_base = 10_000\n",
"\n",
"for head_dim in range(1, 12):\n",
"\n",
" before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
" after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
" \n",
" s = f\"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}\"\n",
" print(s)\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0abfbf38-93a4-4994-8e7e-a543477268a8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
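The notebook above was a one-off sanity check, deleted by this commit: it verified that the simplified RoPE inverse-frequency computation `torch.arange(0, head_dim // 2) / (head_dim // 2)` matches the conventional form `torch.arange(0, head_dim, 2)[: head_dim // 2] / head_dim`. The two exponents agree exactly for even head dimensions, since 2k / head_dim = k / (head_dim // 2) when head_dim is even; for odd dimensions the denominators differ, which is what the `False` rows in the output show. A minimal standalone sketch of the same check (only the loop bounds differ from the notebook):

import torch

theta_base = 10_000

for head_dim in (8, 9):  # one even, one odd head dimension
    # Simplified form: exponent k / (head_dim // 2) for k = 0, 1, ...
    simplified = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))
    # Conventional form: exponent 2k / head_dim over the even indices 0, 2, 4, ...
    conventional = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: head_dim // 2].float() / head_dim))
    print(f"head_dim={head_dim}: equal={torch.equal(simplified, conventional)}")  # True for 8, False for 9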


@@ -10,14 +10,20 @@ import os
 import sys
 import types
 import nbformat
+from packaging import version
 from typing import Optional, Tuple

 import torch
 import pytest
+import transformers
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

-# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
+transformers_version = transformers.__version__
+
+# LitGPT code function `litgpt_build_rope_cache` from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
 # LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
+
+
 def litgpt_build_rope_cache(
     seq_len: int,
     n_elem: int,
@@ -143,6 +149,7 @@ def test_rope_llama2(notebook):
     context_len = 4096
     num_heads = 4
     head_dim = 16
+    theta_base = 10_000

     # Instantiate RoPE parameters
     cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len)
@@ -156,11 +163,24 @@ def test_rope_llama2(notebook):
     keys_rot = this_nb.compute_rope(keys, cos, sin)

     # Generate reference RoPE via HF
-    rot_emb = LlamaRotaryEmbedding(
-        dim=head_dim,
-        max_position_embeddings=context_len,
-        base=10_000
-    )
+    if version.parse(transformers_version) < version.parse("4.48"):
+        rot_emb = LlamaRotaryEmbedding(
+            dim=head_dim,
+            max_position_embeddings=context_len,
+            base=theta_base
+        )
+    else:
+        class RoPEConfig:
+            dim: int = head_dim
+            rope_theta = theta_base
+            max_position_embeddings: int = 8192
+            hidden_size = head_dim * num_heads
+            num_attention_heads = num_heads
+
+        config = RoPEConfig()
+        rot_emb = LlamaRotaryEmbedding(config=config)

     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
@@ -209,11 +229,22 @@ def test_rope_llama3(notebook):
     keys_rot = nb1.compute_rope(keys, cos, sin)

     # Generate reference RoPE via HF
-    rot_emb = LlamaRotaryEmbedding(
-        dim=head_dim,
-        max_position_embeddings=context_len,
-        base=theta_base
-    )
+    if version.parse(transformers_version) < version.parse("4.48"):
+        rot_emb = LlamaRotaryEmbedding(
+            dim=head_dim,
+            max_position_embeddings=context_len,
+            base=theta_base
+        )
+    else:
+        class RoPEConfig:
+            dim: int = head_dim
+            rope_theta = theta_base
+            max_position_embeddings: int = 8192
+            hidden_size = head_dim * num_heads
+            num_attention_heads = num_heads
+
+        config = RoPEConfig()
+        rot_emb = LlamaRotaryEmbedding(config=config)

     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
     ref_cos, ref_sin = rot_emb(queries, position_ids)
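The version gate above exists because newer transformers releases construct `LlamaRotaryEmbedding` from a model config object instead of individual keyword arguments; per the test's own branch condition, the cutover is 4.48. A condensed, standalone sketch of the same pattern (the minimal `RoPEConfig` mirrors the one in the diff; for the default RoPE type, the head dimension is recovered internally as `hidden_size // num_attention_heads`):

import torch
import transformers
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, num_heads, context_len, theta_base = 16, 4, 4096, 10_000

if version.parse(transformers.__version__) < version.parse("4.48"):
    # Pre-4.48 API: RoPE hyperparameters are passed directly.
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim, max_position_embeddings=context_len, base=theta_base
    )
else:
    # 4.48+ API: the same hyperparameters are read from a config object.
    class RoPEConfig:
        rope_theta = theta_base
        max_position_embeddings = 8192
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    rot_emb = LlamaRotaryEmbedding(config=RoPEConfig())

queries = torch.randn(1, num_heads, context_len, head_dim)
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
ref_cos, ref_sin = rot_emb(queries, position_ids)  # reference cos/sin tables from HF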


@@ -230,13 +230,34 @@
    ],
    "source": [
     "import json\n",
+    "import os\n",
+    "import urllib\n",
+    "\n",
+    "\n",
+    "def download_and_load_file(file_path, url):\n",
+    "\n",
+    "    if not os.path.exists(file_path):\n",
+    "        with urllib.request.urlopen(url) as response:\n",
+    "            text_data = response.read().decode(\"utf-8\")\n",
+    "        with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
+    "            file.write(text_data)\n",
+    "    else:\n",
+    "        with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
+    "            text_data = file.read()\n",
+    "\n",
+    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
+    "        data = json.load(file)\n",
+    "\n",
+    "    return data\n",
     "\n",
     "\n",
     "file_path = \"instruction-data-with-preference.json\"\n",
+    "url = (\n",
+    "    \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch\"\n",
+    "    \"/main/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json\"\n",
+    ")\n",
     "\n",
-    "with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
-    "    data = json.load(file)\n",
-    "\n",
+    "data = download_and_load_file(file_path, url)\n",
     "print(\"Number of entries:\", len(data))"
    ]
   },
@@ -1546,7 +1567,6 @@
    },
    "outputs": [],
    "source": [
-    "import os\n",
     "from pathlib import Path\n",
     "import shutil\n",
     "\n",