Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-11-03 19:30:26 +00:00
	Auto download DPO dataset if not already available in path (#479)
* Auto download DPO dataset if not already available in path
* Update unit tests to account for the latest HF transformers release
* PEP 8
This commit is contained in:
		
							parent
							
								
									a48f9c7fe2
								
							
						
					
					
						commit
						4bfbcd069d
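For orientation, the notebook cell that this commit adds (JSON-escaped in the DPO-notebook hunk further below) reassembles into roughly the following plain Python. This is a lightly adapted sketch, not the notebook's verbatim text: it imports `urllib.request` explicitly and drops the redundant re-read in the `else` branch.

```python
import json
import os
import urllib.request


def download_and_load_file(file_path, url):
    # Download the file only if it is not already present locally
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)

    # Load the (possibly just downloaded) JSON file
    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data


file_path = "instruction-data-with-preference.json"
url = (
    "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch"
    "/main/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json"
)

data = download_and_load_file(file_path, url)
print("Number of entries:", len(data))
```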
					
@@ -1,74 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "40d2405d-ee10-44ad-b20e-cf32078f926a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True | head dim: 1, tensor([]), tensor([])\n",
      "True | head dim: 2, tensor([1.]), tensor([1.])\n",
      "True | head dim: 3, tensor([1.]), tensor([1.])\n",
      "True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n",
      "False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n",
      "True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n",
      "False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n",
      "True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n",
      "False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n",
      "True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n",
      "False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "theta_base = 10_000\n",
    "\n",
    "for head_dim in range(1, 12):\n",
    "\n",
    "    before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
    "    after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
    "    \n",
    "    s = f\"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}\"\n",
    "    print(s)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0abfbf38-93a4-4994-8e7e-a543477268a8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
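For readability, the code in the deleted scratch cell, reassembled from the notebook JSON above into plain Python (only the comments are added here):

```python
import torch

theta_base = 10_000

for head_dim in range(1, 12):
    # Inverse frequencies computed two ways: normalize the index by head_dim // 2 ...
    before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))
    # ... versus stepping over even indices and normalizing by head_dim
    after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

    s = f"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}"
    print(s)
```

As the captured stdout above records, the two formulations agree for even head dimensions but diverge for odd head dimensions of 5 and above.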
@@ -10,14 +10,20 @@ import os
import sys
import types
import nbformat
from packaging import version
from typing import Optional, Tuple
import torch
import pytest
import transformers
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
transformers_version = transformers.__version__

# LitGPT code function `litgpt_build_rope_cache` from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE


def litgpt_build_rope_cache(
    seq_len: int,
    n_elem: int,
@@ -143,6 +149,7 @@ def test_rope_llama2(notebook):
    context_len = 4096
    num_heads = 4
    head_dim = 16
    theta_base = 10_000

    # Instantiate RoPE parameters
    cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len)
@@ -156,11 +163,24 @@ def test_rope_llama2(notebook):
    keys_rot = this_nb.compute_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=10_000
    )

    if version.parse(transformers_version) < version.parse("4.48"):
        rot_emb = LlamaRotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=context_len,
            base=theta_base
        )
    else:
        class RoPEConfig:
            dim: int = head_dim
            rope_theta = theta_base
            max_position_embeddings: int = 8192
            hidden_size = head_dim * num_heads
            num_attention_heads = num_heads

        config = RoPEConfig()
        rot_emb = LlamaRotaryEmbedding(config=config)

    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
@@ -209,11 +229,22 @@ def test_rope_llama3(notebook):
    keys_rot = nb1.compute_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=theta_base
    )
    if version.parse(transformers_version) < version.parse("4.48"):
        rot_emb = LlamaRotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=context_len,
            base=theta_base
        )
    else:
        class RoPEConfig:
            dim: int = head_dim
            rope_theta = theta_base
            max_position_embeddings: int = 8192
            hidden_size = head_dim * num_heads
            num_attention_heads = num_heads

        config = RoPEConfig()
        rot_emb = LlamaRotaryEmbedding(config=config)

    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
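Both tests now guard the reference-RoPE construction on the installed transformers version: as the gate above indicates, releases from 4.48 onward expect `LlamaRotaryEmbedding` to be built from a config object rather than from individual `dim`/`max_position_embeddings`/`base` keyword arguments. A condensed sketch of that pattern, mirroring the test code (the minimal `RoPEConfig` class is the tests' own stand-in, not an official transformers config):

```python
from packaging import version
import transformers
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, num_heads, context_len, theta_base = 16, 4, 4096, 10_000

if version.parse(transformers.__version__) < version.parse("4.48"):
    # Older releases accept the RoPE hyperparameters directly
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=theta_base,
    )
else:
    # Newer releases read the same hyperparameters from a config-like object
    class RoPEConfig:
        dim: int = head_dim
        rope_theta = theta_base
        max_position_embeddings: int = 8192
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    rot_emb = LlamaRotaryEmbedding(config=RoPEConfig())
```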
@@ -230,13 +230,34 @@
   ],
   "source": [
    "import json\n",
    "import os\n",
    "import urllib\n",
    "\n",
    "\n",
    "def download_and_load_file(file_path, url):\n",
    "\n",
    "    if not os.path.exists(file_path):\n",
    "        with urllib.request.urlopen(url) as response:\n",
    "            text_data = response.read().decode(\"utf-8\")\n",
    "        with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
    "            file.write(text_data)\n",
    "    else:\n",
    "        with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "            text_data = file.read()\n",
    "\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        data = json.load(file)\n",
    "\n",
    "    return data\n",
    "\n",
    "\n",
    "file_path = \"instruction-data-with-preference.json\"\n",
    "url = (\n",
    "    \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch\"\n",
    "    \"/main/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json\"\n",
    ")\n",
    "\n",
    "with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "    data = json.load(file)\n",
    "\n",
    "data = download_and_load_file(file_path, url)\n",
    "print(\"Number of entries:\", len(data))"
   ]
  },
@@ -1546,7 +1567,6 @@
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import shutil\n",
    "\n",