mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-28 18:40:01 +00:00
Auto download DPO dataset if not already available in path (#479)
* Auto download DPO dataset if not already available in path * update tests to account for latest HF transformers release in unit tests * pep 8
This commit is contained in:
parent
05f2a398b8
commit
992f3068d1
@ -1,74 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "40d2405d-ee10-44ad-b20e-cf32078f926a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"True | head dim: 1, tensor([]), tensor([])\n",
|
|
||||||
"True | head dim: 2, tensor([1.]), tensor([1.])\n",
|
|
||||||
"True | head dim: 3, tensor([1.]), tensor([1.])\n",
|
|
||||||
"True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n",
|
|
||||||
"False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n",
|
|
||||||
"True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n",
|
|
||||||
"False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n",
|
|
||||||
"True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n",
|
|
||||||
"False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n",
|
|
||||||
"True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n",
|
|
||||||
"False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import torch\n",
|
|
||||||
"\n",
|
|
||||||
"theta_base = 10_000\n",
|
|
||||||
"\n",
|
|
||||||
"for head_dim in range(1, 12):\n",
|
|
||||||
"\n",
|
|
||||||
" before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
|
|
||||||
" after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
|
|
||||||
" \n",
|
|
||||||
" s = f\"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}\"\n",
|
|
||||||
" print(s)\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "0abfbf38-93a4-4994-8e7e-a543477268a8",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.6"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
@ -10,14 +10,20 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
import nbformat
|
import nbformat
|
||||||
|
from packaging import version
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
|
import transformers
|
||||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
|
transformers_version = transformers.__version__
|
||||||
|
|
||||||
|
# LitGPT code function `litgpt_build_rope_cache` from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
|
||||||
# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
|
# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
|
||||||
|
|
||||||
|
|
||||||
def litgpt_build_rope_cache(
|
def litgpt_build_rope_cache(
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
n_elem: int,
|
n_elem: int,
|
||||||
@ -143,6 +149,7 @@ def test_rope_llama2(notebook):
|
|||||||
context_len = 4096
|
context_len = 4096
|
||||||
num_heads = 4
|
num_heads = 4
|
||||||
head_dim = 16
|
head_dim = 16
|
||||||
|
theta_base = 10_000
|
||||||
|
|
||||||
# Instantiate RoPE parameters
|
# Instantiate RoPE parameters
|
||||||
cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len)
|
cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len)
|
||||||
@ -156,11 +163,24 @@ def test_rope_llama2(notebook):
|
|||||||
keys_rot = this_nb.compute_rope(keys, cos, sin)
|
keys_rot = this_nb.compute_rope(keys, cos, sin)
|
||||||
|
|
||||||
# Generate reference RoPE via HF
|
# Generate reference RoPE via HF
|
||||||
rot_emb = LlamaRotaryEmbedding(
|
|
||||||
dim=head_dim,
|
if version.parse(transformers_version) < version.parse("4.48"):
|
||||||
max_position_embeddings=context_len,
|
rot_emb = LlamaRotaryEmbedding(
|
||||||
base=10_000
|
dim=head_dim,
|
||||||
)
|
max_position_embeddings=context_len,
|
||||||
|
base=theta_base
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
class RoPEConfig:
|
||||||
|
dim: int = head_dim
|
||||||
|
rope_theta = theta_base
|
||||||
|
max_position_embeddings: int = 8192
|
||||||
|
hidden_size = head_dim * num_heads
|
||||||
|
num_attention_heads = num_heads
|
||||||
|
|
||||||
|
config = RoPEConfig()
|
||||||
|
rot_emb = LlamaRotaryEmbedding(config=config)
|
||||||
|
|
||||||
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
|
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
|
||||||
ref_cos, ref_sin = rot_emb(queries, position_ids)
|
ref_cos, ref_sin = rot_emb(queries, position_ids)
|
||||||
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
|
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
|
||||||
@ -209,11 +229,22 @@ def test_rope_llama3(notebook):
|
|||||||
keys_rot = nb1.compute_rope(keys, cos, sin)
|
keys_rot = nb1.compute_rope(keys, cos, sin)
|
||||||
|
|
||||||
# Generate reference RoPE via HF
|
# Generate reference RoPE via HF
|
||||||
rot_emb = LlamaRotaryEmbedding(
|
if version.parse(transformers_version) < version.parse("4.48"):
|
||||||
dim=head_dim,
|
rot_emb = LlamaRotaryEmbedding(
|
||||||
max_position_embeddings=context_len,
|
dim=head_dim,
|
||||||
base=theta_base
|
max_position_embeddings=context_len,
|
||||||
)
|
base=theta_base
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
class RoPEConfig:
|
||||||
|
dim: int = head_dim
|
||||||
|
rope_theta = theta_base
|
||||||
|
max_position_embeddings: int = 8192
|
||||||
|
hidden_size = head_dim * num_heads
|
||||||
|
num_attention_heads = num_heads
|
||||||
|
|
||||||
|
config = RoPEConfig()
|
||||||
|
rot_emb = LlamaRotaryEmbedding(config=config)
|
||||||
|
|
||||||
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
|
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
|
||||||
ref_cos, ref_sin = rot_emb(queries, position_ids)
|
ref_cos, ref_sin = rot_emb(queries, position_ids)
|
||||||
|
@ -230,13 +230,34 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import json\n",
|
"import json\n",
|
||||||
|
"import os\n",
|
||||||
|
"import urllib\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def download_and_load_file(file_path, url):\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(file_path):\n",
|
||||||
|
" with urllib.request.urlopen(url) as response:\n",
|
||||||
|
" text_data = response.read().decode(\"utf-8\")\n",
|
||||||
|
" with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
|
||||||
|
" file.write(text_data)\n",
|
||||||
|
" else:\n",
|
||||||
|
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
||||||
|
" text_data = file.read()\n",
|
||||||
|
"\n",
|
||||||
|
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
||||||
|
" data = json.load(file)\n",
|
||||||
|
"\n",
|
||||||
|
" return data\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"file_path = \"instruction-data-with-preference.json\"\n",
|
"file_path = \"instruction-data-with-preference.json\"\n",
|
||||||
|
"url = (\n",
|
||||||
|
" \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch\"\n",
|
||||||
|
" \"/main/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json\"\n",
|
||||||
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
"data = download_and_load_file(file_path, url)\n",
|
||||||
" data = json.load(file)\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"Number of entries:\", len(data))"
|
"print(\"Number of entries:\", len(data))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -1546,7 +1567,6 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
|
||||||
"from pathlib import Path\n",
|
"from pathlib import Path\n",
|
||||||
"import shutil\n",
|
"import shutil\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user