mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-26 15:29:25 +00:00 
			
		
		
		
	Auto download DPO dataset if not already available in path (#479)
* Auto download DPO dataset if not already available in path * update tests to account for latest HF transformers release in unit tests * pep 8
This commit is contained in:
		
							parent
							
								
									05f2a398b8
								
							
						
					
					
						commit
						992f3068d1
					
				| @ -1,74 +0,0 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 9, | ||||
|    "id": "40d2405d-ee10-44ad-b20e-cf32078f926a", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "True | head dim: 1, tensor([]), tensor([])\n", | ||||
|       "True | head dim: 2, tensor([1.]), tensor([1.])\n", | ||||
|       "True | head dim: 3, tensor([1.]), tensor([1.])\n", | ||||
|       "True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n", | ||||
|       "False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n", | ||||
|       "True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n", | ||||
|       "False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n", | ||||
|       "True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n", | ||||
|       "False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n", | ||||
|       "True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n", | ||||
|       "False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "import torch\n", | ||||
|     "\n", | ||||
|     "theta_base = 10_000\n", | ||||
|     "\n", | ||||
|     "for head_dim in range(1, 12):\n", | ||||
|     "\n", | ||||
|     "    before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", | ||||
|     "    after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n", | ||||
|     "    \n", | ||||
|     "    s = f\"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}\"\n", | ||||
|     "    print(s)\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "0abfbf38-93a4-4994-8e7e-a543477268a8", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3 (ipykernel)", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.10.6" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||
| @ -10,14 +10,20 @@ import os | ||||
| import sys | ||||
| import types | ||||
| import nbformat | ||||
| from packaging import version | ||||
| from typing import Optional, Tuple | ||||
| import torch | ||||
| import pytest | ||||
| import transformers | ||||
| from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb | ||||
| 
 | ||||
| 
 | ||||
| # LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py | ||||
| transformers_version = transformers.__version__ | ||||
| 
 | ||||
| # LitGPT code function `litgpt_build_rope_cache` from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py | ||||
| # LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE | ||||
| 
 | ||||
| 
 | ||||
| def litgpt_build_rope_cache( | ||||
|     seq_len: int, | ||||
|     n_elem: int, | ||||
| @ -143,6 +149,7 @@ def test_rope_llama2(notebook): | ||||
|     context_len = 4096 | ||||
|     num_heads = 4 | ||||
|     head_dim = 16 | ||||
|     theta_base = 10_000 | ||||
| 
 | ||||
|     # Instantiate RoPE parameters | ||||
|     cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len) | ||||
| @ -156,11 +163,24 @@ def test_rope_llama2(notebook): | ||||
|     keys_rot = this_nb.compute_rope(keys, cos, sin) | ||||
| 
 | ||||
|     # Generate reference RoPE via HF | ||||
| 
 | ||||
|     if version.parse(transformers_version) < version.parse("4.48"): | ||||
|         rot_emb = LlamaRotaryEmbedding( | ||||
|             dim=head_dim, | ||||
|             max_position_embeddings=context_len, | ||||
|         base=10_000 | ||||
|             base=theta_base | ||||
|         ) | ||||
|     else: | ||||
|         class RoPEConfig: | ||||
|             dim: int = head_dim | ||||
|             rope_theta = theta_base | ||||
|             max_position_embeddings: int = 8192 | ||||
|             hidden_size = head_dim * num_heads | ||||
|             num_attention_heads = num_heads | ||||
| 
 | ||||
|         config = RoPEConfig() | ||||
|         rot_emb = LlamaRotaryEmbedding(config=config) | ||||
| 
 | ||||
|     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0) | ||||
|     ref_cos, ref_sin = rot_emb(queries, position_ids) | ||||
|     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin) | ||||
| @ -209,11 +229,22 @@ def test_rope_llama3(notebook): | ||||
|     keys_rot = nb1.compute_rope(keys, cos, sin) | ||||
| 
 | ||||
|     # Generate reference RoPE via HF | ||||
|     if version.parse(transformers_version) < version.parse("4.48"): | ||||
|         rot_emb = LlamaRotaryEmbedding( | ||||
|             dim=head_dim, | ||||
|             max_position_embeddings=context_len, | ||||
|             base=theta_base | ||||
|         ) | ||||
|     else: | ||||
|         class RoPEConfig: | ||||
|             dim: int = head_dim | ||||
|             rope_theta = theta_base | ||||
|             max_position_embeddings: int = 8192 | ||||
|             hidden_size = head_dim * num_heads | ||||
|             num_attention_heads = num_heads | ||||
| 
 | ||||
|         config = RoPEConfig() | ||||
|         rot_emb = LlamaRotaryEmbedding(config=config) | ||||
| 
 | ||||
|     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0) | ||||
|     ref_cos, ref_sin = rot_emb(queries, position_ids) | ||||
|  | ||||
| @ -230,13 +230,34 @@ | ||||
|    ], | ||||
|    "source": [ | ||||
|     "import json\n", | ||||
|     "import os\n", | ||||
|     "import urllib\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "def download_and_load_file(file_path, url):\n", | ||||
|     "\n", | ||||
|     "    if not os.path.exists(file_path):\n", | ||||
|     "        with urllib.request.urlopen(url) as response:\n", | ||||
|     "            text_data = response.read().decode(\"utf-8\")\n", | ||||
|     "        with open(file_path, \"w\", encoding=\"utf-8\") as file:\n", | ||||
|     "            file.write(text_data)\n", | ||||
|     "    else:\n", | ||||
|     "        with open(file_path, \"r\", encoding=\"utf-8\") as file:\n", | ||||
|     "            text_data = file.read()\n", | ||||
|     "\n", | ||||
|     "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n", | ||||
|     "        data = json.load(file)\n", | ||||
|     "\n", | ||||
|     "    return data\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "file_path = \"instruction-data-with-preference.json\"\n", | ||||
|     "url = (\n", | ||||
|     "    \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch\"\n", | ||||
|     "    \"/main/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json\"\n", | ||||
|     ")\n", | ||||
|     "\n", | ||||
|     "with open(file_path, \"r\", encoding=\"utf-8\") as file:\n", | ||||
|     "    data = json.load(file)\n", | ||||
|     "\n", | ||||
|     "data = download_and_load_file(file_path, url)\n", | ||||
|     "print(\"Number of entries:\", len(data))" | ||||
|    ] | ||||
|   }, | ||||
| @ -1546,7 +1567,6 @@ | ||||
|    }, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import os\n", | ||||
|     "from pathlib import Path\n", | ||||
|     "import shutil\n", | ||||
|     "\n", | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian Raschka
						Sebastian Raschka