{
 | 
						||
 "cells": [
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "0_xya1nyDHfY",
 | 
						||
   "metadata": {
 | 
						||
    "id": "0_xya1nyDHfY"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "<table style=\"width:100%\">\n",
 | 
						||
    "<tr>\n",
 | 
						||
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
 | 
						||
    "<font size=\"2\">\n",
 | 
						||
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
 | 
						||
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
 | 
						||
    "</font>\n",
 | 
						||
    "</td>\n",
 | 
						||
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
 | 
						||
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
 | 
						||
    "</td>\n",
 | 
						||
    "</tr>\n",
 | 
						||
    "</table>"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "l62zIRRSBy_R",
 | 
						||
   "metadata": {
 | 
						||
    "id": "l62zIRRSBy_R"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "# Converting Llama 2 to Llama 3.2 From Scratch"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "aFmxTQbwCUMl",
 | 
						||
   "metadata": {
 | 
						||
    "id": "aFmxTQbwCUMl"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- This is a follow-up notebook to [Converting a From-Scratch GPT Architecture to Llama 2](./converting-gpt-to-llama2.ipynb), converting Meta AI's Llama 2 architecture model step by step to Llama 3, Llama 3.1, and Llama 3.2\n",
 | 
						||
    "- The explanations are purposefully kept minimal in this notebook so as not to bloat it unnecessarily and focus on the main code\n",
 | 
						||
    "- For more information about the architectures, please see the Llama 2 and Llama 3 papers\n",
 | 
						||
    " - [Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)](https://arxiv.org/abs/2307.09288)\n",
 | 
						||
    " - [The Llama 3 Herd of Models](https://arxiv.org/abs/2407.21783)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "ohhMKUWvGm9z",
 | 
						||
   "metadata": {
 | 
						||
    "id": "ohhMKUWvGm9z"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/gpt2-to-llama2-llama3.webp?1\">"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 1,
 | 
						||
   "id": "ws0wsUzwLH2k",
 | 
						||
   "metadata": {
 | 
						||
    "id": "ws0wsUzwLH2k"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# pip install -r requirements-extra.txt"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "JBpQwU89ETA1",
 | 
						||
   "metadata": {
 | 
						||
    "id": "JBpQwU89ETA1"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Packages that are being used in this notebook:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 2,
 | 
						||
   "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
 | 
						||
    "outputId": "e3d3d4b6-ee63-4e28-d794-e8b0bdd931fd"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "blobfile version: 3.0.0\n",
 | 
						||
      "huggingface_hub version: 0.24.7\n",
 | 
						||
      "tiktoken version: 0.8.0\n",
 | 
						||
      "torch version: 2.4.1+cu121\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "from importlib.metadata import version\n",
 | 
						||
    "\n",
 | 
						||
    "pkgs = [\n",
 | 
						||
    "    \"blobfile\",         # to download pretrained weights\n",
 | 
						||
    "    \"huggingface_hub\",  # to download pretrained weights\n",
 | 
						||
    "    \"tiktoken\",         # to implement the tokenizer\n",
 | 
						||
    "    \"torch\",            # to implement the model\n",
 | 
						||
    "]\n",
 | 
						||
    "for p in pkgs:\n",
 | 
						||
    "    print(f\"{p} version: {version(p)}\")"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "UJJneXpTEg4W",
 | 
						||
   "metadata": {
 | 
						||
    "id": "UJJneXpTEg4W"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "# 1. Convert the Llama model implementation step by step"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "v1zpfX2GHBKa",
 | 
						||
   "metadata": {
 | 
						||
    "id": "v1zpfX2GHBKa"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- If you are new to implementing LLM architectures, I recommend starting with [chapter 4](../../ch04/01_main-chapter-code/ch04.ipynb), which walks you through the implementation of the original GPT architecture step by step\n",
 | 
						||
    "- The [Converting a From-Scratch GPT Architecture to Llama 2](./converting-gpt-to-llama2.ipynb) then implements the Llama-specific components, such as RMSNorm layers, SiLU and SwiGLU activations, RoPE (rotary position embeddings), and the SentencePiece tokenizer\n",
 | 
						||
    "- This notebook takes the Llama 2 architecture and transforms it into Llama 3 architecture by\n",
 | 
						||
    "    1. modifying the rotary embeddings\n",
 | 
						||
    "    2. implementing grouped-query attention\n",
 | 
						||
    "    3. and using a customized version of the GPT-4 tokenizer\n",
 | 
						||
    "- Later, we then load the original Llama 3 weights shared by Meta AI into the architecture"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "c14b9121-abe1-4a46-99b8-acdef71e5b41",
 | 
						||
   "metadata": {
 | 
						||
    "id": "c14b9121-abe1-4a46-99b8-acdef71e5b41"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 1.1 Reusing Llama 2 components"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "dgDhJGJ6xR4e",
 | 
						||
   "metadata": {
 | 
						||
    "id": "dgDhJGJ6xR4e"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Llama 2 is actually quite similar to Llama 3, as mentioned above and illustrated in the figure at the top of this notebook\n",
 | 
						||
    "- This means that we can import several building blocks from the [Llama 2 notebook](./converting-gpt-to-llama2.ipynb) using the following code"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 3,
 | 
						||
   "id": "a5bc3948-231b-4f1f-8d41-24ad0b7643d0",
 | 
						||
   "metadata": {
 | 
						||
    "id": "a5bc3948-231b-4f1f-8d41-24ad0b7643d0"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "import os\n",
 | 
						||
    "import sys\n",
 | 
						||
    "import io\n",
 | 
						||
    "import nbformat\n",
 | 
						||
    "import types\n",
 | 
						||
    "\n",
 | 
						||
    "def import_from_notebook():\n",
 | 
						||
    "    def import_definitions_from_notebook(fullname, names):\n",
 | 
						||
    "        current_dir = os.getcwd()\n",
 | 
						||
    "        path = os.path.join(current_dir, fullname + \".ipynb\")\n",
 | 
						||
    "        path = os.path.normpath(path)\n",
 | 
						||
    "\n",
 | 
						||
    "        # Load the notebook\n",
 | 
						||
    "        if not os.path.exists(path):\n",
 | 
						||
    "            raise FileNotFoundError(f\"Notebook file not found at: {path}\")\n",
 | 
						||
    "\n",
 | 
						||
    "        with io.open(path, \"r\", encoding=\"utf-8\") as f:\n",
 | 
						||
    "            nb = nbformat.read(f, as_version=4)\n",
 | 
						||
    "\n",
 | 
						||
    "        # Create a module to store the imported functions and classes\n",
 | 
						||
    "        mod = types.ModuleType(fullname)\n",
 | 
						||
    "        sys.modules[fullname] = mod\n",
 | 
						||
    "\n",
 | 
						||
    "        # Go through the notebook cells and only execute function or class definitions\n",
 | 
						||
    "        for cell in nb.cells:\n",
 | 
						||
    "            if cell.cell_type == \"code\":\n",
 | 
						||
    "                cell_code = cell.source\n",
 | 
						||
    "                for name in names:\n",
 | 
						||
    "                    # Check for function or class definitions\n",
 | 
						||
    "                    if f\"def {name}\" in cell_code or f\"class {name}\" in cell_code:\n",
 | 
						||
    "                        exec(cell_code, mod.__dict__)\n",
 | 
						||
    "        return mod\n",
 | 
						||
    "\n",
 | 
						||
    "    fullname = \"converting-gpt-to-llama2\"\n",
 | 
						||
    "    names = [\"precompute_rope_params\", \"compute_rope\", \"SiLU\", \"FeedForward\", \"RMSNorm\", \"MultiHeadAttention\"]\n",
 | 
						||
    "\n",
 | 
						||
    "    return import_definitions_from_notebook(fullname, names)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 4,
 | 
						||
   "id": "d546032d-fce4-47cf-8d0e-682b78b21c61",
 | 
						||
   "metadata": {
 | 
						||
    "id": "d546032d-fce4-47cf-8d0e-682b78b21c61"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "imported_module = import_from_notebook()\n",
 | 
						||
    "\n",
 | 
						||
    "# We need to redefine precompute_rope_params\n",
 | 
						||
    "# precompute_rope_params = getattr(imported_module, \"precompute_rope_params\", None)\n",
 | 
						||
    "compute_rope = getattr(imported_module, \"compute_rope\", None)\n",
 | 
						||
    "SiLU = getattr(imported_module, \"SiLU\", None)\n",
 | 
						||
    "FeedForward = getattr(imported_module, \"FeedForward\", None)\n",
 | 
						||
    "RMSNorm = getattr(imported_module, \"RMSNorm\", None)\n",
 | 
						||
    "\n",
 | 
						||
    "# MultiHeadAttention only for comparison purposes\n",
 | 
						||
    "MultiHeadAttention = getattr(imported_module, \"MultiHeadAttention\", None)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f",
 | 
						||
   "metadata": {
 | 
						||
    "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 1.2 Modified RoPE"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "m9_oDcHCx8VI",
 | 
						||
   "metadata": {
 | 
						||
    "id": "m9_oDcHCx8VI"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Llama 3 uses rotary position embeddings (RoPE) similar to Llama 2 (for a detailed explanation, please see the [RoPE paper](https://arxiv.org/abs/2104.09864))\n",
 | 
						||
    "- There are some subtle differences in the RoPE settings, though\n",
 | 
						||
    " - Llama 3 now supports up to 8,192 tokens, twice as many as Llama 2 (4,096)\n",
 | 
						||
    " - The base value for the so-called RoPE $\\theta$ (see equation below) was increased from 10,000 (Llama 2) to 500,000 (Llama 3) in the following equation (adapted from the [RoPE paper](https://arxiv.org/abs/2104.09864))\n",
 | 
						||
    "\n",
 | 
						||
    "$$\\Theta = \\left\\{\\theta_i = \\text{base}^{\\frac{-2(i-1)}{d}}, i \\in \\left[1, 2, ..., d/2\\right]\\right\\}$$\n",
 | 
						||
    "\n",
 | 
						||
    "- These $\\theta$ values are a set of predefined parameters that are used to determine the rotational angles in the rotary matrix, where $d$ is the dimensionality of the embedding space\n",
 | 
						||
    "- Increasing the base from 10,000 to 500,000 makes the frequencies (or rotation angles) decay more slowly across the dimensions, which means that higher dimensions will be associated with larger angles than before (essentially, it's a decompression of the frequencies)\n",
 | 
						||
    "- In addition, we introduce a `freq_config` section in the code below that adjusts the frequency; however, we won't be needing it in Llama 3 (only Llama 3.1 and Llama 3.2), so we will revisit this `freq_config` later (it's set to `None` and ignored by default)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 5,
 | 
						||
   "id": "6Upl109OOAcu",
 | 
						||
   "metadata": {
 | 
						||
    "id": "6Upl109OOAcu"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "import torch\n",
 | 
						||
    "\n",
 | 
						||
    "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n",
 | 
						||
    "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
 | 
						||
    "\n",
 | 
						||
    "    # Compute the inverse frequencies\n",
 | 
						||
    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
 | 
						||
    "\n",
 | 
						||
    "    ################################ NEW ###############################################\n",
 | 
						||
    "    # Frequency adjustments\n",
 | 
						||
    "    if freq_config is not None:\n",
 | 
						||
    "        low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n",
 | 
						||
    "        high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n",
 | 
						||
    "\n",
 | 
						||
    "        wavelen = 2 * torch.pi / inv_freq\n",
 | 
						||
    "\n",
 | 
						||
    "        inv_freq_llama = torch.where(\n",
 | 
						||
    "            wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n",
 | 
						||
    "        )\n",
 | 
						||
    "\n",
 | 
						||
    "        smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n",
 | 
						||
    "            freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n",
 | 
						||
    "        )\n",
 | 
						||
    "\n",
 | 
						||
    "        smoothed_inv_freq = (\n",
 | 
						||
    "            (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n",
 | 
						||
    "        )\n",
 | 
						||
    "\n",
 | 
						||
    "        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n",
 | 
						||
    "        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n",
 | 
						||
    "        inv_freq = inv_freq_llama\n",
 | 
						||
    "    ####################################################################################\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "    # Generate position indices\n",
 | 
						||
    "    positions = torch.arange(context_length)\n",
 | 
						||
    "\n",
 | 
						||
    "    # Compute the angles\n",
 | 
						||
    "    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)\n",
 | 
						||
    "\n",
 | 
						||
    "    # Expand angles to match the head_dim\n",
 | 
						||
    "    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)\n",
 | 
						||
    "\n",
 | 
						||
    "    # Precompute sine and cosine\n",
 | 
						||
    "    cos = torch.cos(angles)\n",
 | 
						||
    "    sin = torch.sin(angles)\n",
 | 
						||
    "\n",
 | 
						||
    "    return cos, sin"
 | 
						||
   ]
 | 
						||
  },
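  {
   "cell_type": "markdown",
   "id": "rope-theta-base-aside",
   "metadata": {},
   "source": [
    "- As a quick optional aside (added for illustration; not required for the rest of the notebook), the next cell compares the inverse frequencies $\\theta_i$ implied by the two base values\n",
    "- With the larger base of 500,000, the higher dimensions end up with much smaller $\\theta_i$ values (lower frequencies, longer wavelengths), which is what keeps positional information distinguishable over longer contexts; the `demo_head_dim` value below is an arbitrary small number chosen for readability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rope-theta-base-aside-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch  # already imported above; repeated so this aside is self-contained\n",
    "\n",
    "demo_head_dim = 16  # arbitrary small head dimension for illustration\n",
    "\n",
    "for base in (10_000, 500_000):\n",
    "    inv_freq = 1.0 / (base ** (torch.arange(0, demo_head_dim, 2).float() / demo_head_dim))\n",
    "    print(f\"theta_base={base:>7,}: \" + \", \".join(f\"{f:.1e}\" for f in inv_freq.tolist()))"
   ]
  },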
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "jJBvO0YMJBXR",
 | 
						||
   "metadata": {
 | 
						||
    "id": "jJBvO0YMJBXR"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- To summarize, what's new so far for Llama 3 compared to Llama 2 are the context length and theta base parameter:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 6,
 | 
						||
   "id": "56c37216-e022-4603-be16-f9d3eaeaf4a1",
 | 
						||
   "metadata": {
 | 
						||
    "id": "56c37216-e022-4603-be16-f9d3eaeaf4a1"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# Instantiate RoPE parameters\n",
 | 
						||
    "\n",
 | 
						||
    "llama_2_context_len = 4096\n",
 | 
						||
    "llama_3_context_len = 8192\n",
 | 
						||
    "\n",
 | 
						||
    "llama_2_theta_base = 10_000\n",
 | 
						||
    "llama_3_theta_base = 500_000"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "_V8v6i7MJItU",
 | 
						||
   "metadata": {
 | 
						||
    "id": "_V8v6i7MJItU"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- The usage remains the same as before in Llama 2:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 7,
 | 
						||
   "id": "dae70c8a-eb18-40f9-a2e5-a6af2a57628b",
 | 
						||
   "metadata": {
 | 
						||
    "id": "dae70c8a-eb18-40f9-a2e5-a6af2a57628b"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# Settings\n",
 | 
						||
    "batch_size = 2\n",
 | 
						||
    "num_heads = 4\n",
 | 
						||
    "head_dim = 16\n",
 | 
						||
    "\n",
 | 
						||
    "# Instantiate RoPE parameters\n",
 | 
						||
    "cos, sin = precompute_rope_params(\n",
 | 
						||
    "    head_dim=head_dim,\n",
 | 
						||
    "    theta_base=llama_3_theta_base,\n",
 | 
						||
    "    context_length=llama_3_context_len\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "# Dummy query and key tensors\n",
 | 
						||
    "torch.manual_seed(123)\n",
 | 
						||
    "queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n",
 | 
						||
    "keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n",
 | 
						||
    "\n",
 | 
						||
    "# Apply rotary position embeddings\n",
 | 
						||
    "queries_rot = compute_rope(queries, cos, sin)\n",
 | 
						||
    "keys_rot = compute_rope(keys, cos, sin)"
 | 
						||
   ]
 | 
						||
  },
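  {
   "cell_type": "markdown",
   "id": "rope-sanity-check-note",
   "metadata": {},
   "source": [
    "- A small optional sanity check (added for illustration, assuming the `compute_rope` function imported above rotates pairs of dimensions as in the Llama 2 notebook): since RoPE is a pure rotation, the shapes should be unchanged and the norm of each head vector should be preserved up to floating-point error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rope-sanity-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: RoPE should not change shapes or (numerically) vector norms\n",
    "print(queries_rot.shape == queries.shape and keys_rot.shape == keys.shape)\n",
    "print(torch.allclose(queries_rot.norm(dim=-1), queries.norm(dim=-1), atol=1e-4))"
   ]
  },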
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "cd19b75c-cf25-47b8-a010-6733fc0e9a8a",
 | 
						||
   "metadata": {
 | 
						||
    "id": "cd19b75c-cf25-47b8-a010-6733fc0e9a8a"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 1.3 Grouped-query attention"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "111c7d3f-fded-49e8-a617-9fe67b81dddc",
 | 
						||
   "metadata": {
 | 
						||
    "id": "111c7d3f-fded-49e8-a617-9fe67b81dddc"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- In this section, we replace multi-head attention (MHA) with an alternative mechanism called grouped-query attention (GQA)\n",
 | 
						||
    "- In short, one can think of GQA as a more compute- and parameter-efficient version of MHA\n",
 | 
						||
    "- In GQA, we reduce the number of key and value projections by sharing them among multiple attention heads\n",
 | 
						||
    "- Each attention head still has its unique query, but these queries attend to the same group of keys and values\n",
 | 
						||
    "- Below is an illustration of GQA with 2 key-value-groups (kv-groups):\n",
 | 
						||
    "\n",
 | 
						||
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/grouped-query-attention.webp\" width=\"500px\">\n"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "perAYa2R_KW2",
 | 
						||
   "metadata": {
 | 
						||
    "id": "perAYa2R_KW2"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- The main idea behind GQA is to reduce the number of unique query groups that attend to the key-value pairs, reducing the size of some of the matrix multiplications and the number of parameters in MHA without significantly reducing modeling performance\n",
 | 
						||
    "- The GQA code is very similar to MHA (I highlighted the changes below via the \"NEW\" sections)\n",
 | 
						||
    "- In short, the main change in GQA is that each query group needs to be repeated to match the number of heads it is associated with, as implemented below"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "842aa71a-4659-424e-8830-392bd6ae86af",
 | 
						||
   "metadata": {},
 | 
						||
   "source": [
 | 
						||
    "- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 8,
 | 
						||
   "id": "9b12e674-ef08-4dd7-8843-615b65b39c91",
 | 
						||
   "metadata": {
 | 
						||
    "id": "9b12e674-ef08-4dd7-8843-615b65b39c91"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "import torch.nn as nn\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "############################# NEW  #############################\n",
 | 
						||
    "class SharedBuffers:\n",
 | 
						||
    "    _buffers = {}\n",
 | 
						||
    "\n",
 | 
						||
    "    @staticmethod\n",
 | 
						||
    "    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
 | 
						||
    "        key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
 | 
						||
    "\n",
 | 
						||
    "        if key not in SharedBuffers._buffers:\n",
 | 
						||
    "            # Create or fetch the buffers\n",
 | 
						||
    "            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
 | 
						||
    "            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
 | 
						||
    "            if dtype is not None:\n",
 | 
						||
    "                cos = cos.to(dtype)\n",
 | 
						||
    "                sin = sin.to(dtype)\n",
 | 
						||
    "            SharedBuffers._buffers[key] = (mask, cos, sin)\n",
 | 
						||
    "\n",
 | 
						||
    "        return SharedBuffers._buffers[key]\n",
 | 
						||
    "############################# NEW  #############################\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "class GroupedQueryAttention(nn.Module):\n",
 | 
						||
    "    def __init__(\n",
 | 
						||
    "            self, d_in, d_out, context_length, num_heads,\n",
 | 
						||
    "            num_kv_groups,       # NEW\n",
 | 
						||
    "            rope_base=10_000,    # NEW\n",
 | 
						||
    "            rope_config=None,    # NEW\n",
 | 
						||
    "            dtype=None\n",
 | 
						||
    "        ):\n",
 | 
						||
    "        super().__init__()\n",
 | 
						||
    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
 | 
						||
    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"  # NEW\n",
 | 
						||
    "\n",
 | 
						||
    "        self.d_out = d_out\n",
 | 
						||
    "        self.num_heads = num_heads\n",
 | 
						||
    "        self.head_dim = d_out // num_heads\n",
 | 
						||
    "\n",
 | 
						||
    "        ############################# NEW  #############################\n",
 | 
						||
    "        # self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
 | 
						||
    "        # self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
 | 
						||
    "        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
 | 
						||
    "        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
 | 
						||
    "        self.num_kv_groups = num_kv_groups\n",
 | 
						||
    "        self.group_size = num_heads // num_kv_groups\n",
 | 
						||
    "        ################################################################\n",
 | 
						||
    "\n",
 | 
						||
    "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
 | 
						||
    "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
 | 
						||
    "\n",
 | 
						||
    "        ############################# NEW  #############################\n",
 | 
						||
    "        # Fetch buffers using SharedBuffers\n",
 | 
						||
    "        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
 | 
						||
    "        ############################# NEW  #############################\n",
 | 
						||
    "        \n",
 | 
						||
    "        self.register_buffer(\"mask\", mask)\n",
 | 
						||
    "        self.register_buffer(\"cos\", cos)\n",
 | 
						||
    "        self.register_buffer(\"sin\", sin)\n",
 | 
						||
    "\n",
 | 
						||
    "    def forward(self, x):\n",
 | 
						||
    "        b, num_tokens, d_in = x.shape\n",
 | 
						||
    "\n",
 | 
						||
    "        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)\n",
 | 
						||
    "        keys = self.W_key(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
 | 
						||
    "        values = self.W_value(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
 | 
						||
    "\n",
 | 
						||
    "        # Reshape queries, keys, and values\n",
 | 
						||
    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
 | 
						||
    "\n",
 | 
						||
    "        ##################### NEW  #####################\n",
 | 
						||
    "        # keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
 | 
						||
    "        # values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
 | 
						||
    "        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
 | 
						||
    "        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
 | 
						||
    "        ################################################\n",
 | 
						||
    "\n",
 | 
						||
    "        # Transpose keys, values, and queries\n",
 | 
						||
    "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
 | 
						||
    "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
 | 
						||
    "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
 | 
						||
    "\n",
 | 
						||
    "        # Apply RoPE\n",
 | 
						||
    "        keys = compute_rope(keys, self.cos, self.sin)\n",
 | 
						||
    "        queries = compute_rope(queries, self.cos, self.sin)\n",
 | 
						||
    "\n",
 | 
						||
    "        ##################### NEW  #####################\n",
 | 
						||
    "        # Expand keys and values to match the number of heads\n",
 | 
						||
    "        # Shape: (b, num_heads, num_tokens, head_dim)\n",
 | 
						||
    "\n",
 | 
						||
    "        keys = keys.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
 | 
						||
    "        values = values.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
 | 
						||
    "        # For example, before repeat_interleave along dim=1 (query groups):\n",
 | 
						||
    "        #   [K1, K2]\n",
 | 
						||
    "        # After repeat_interleave (each query group is repeated group_size times):\n",
 | 
						||
    "        #   [K1, K1, K2, K2]\n",
 | 
						||
    "        # If we used regular repeat instead of repeat_interleave, we'd get:\n",
 | 
						||
    "        #   [K1, K2, K1, K2]\n",
 | 
						||
    "        ################################################\n",
 | 
						||
    "\n",
 | 
						||
    "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
 | 
						||
    "        # Shape: (b, num_heads, num_tokens, num_tokens)\n",
 | 
						||
    "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
 | 
						||
    "\n",
 | 
						||
    "        # Original mask truncated to the number of tokens and converted to boolean\n",
 | 
						||
    "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
 | 
						||
    "\n",
 | 
						||
    "        # Use the mask to fill attention scores\n",
 | 
						||
    "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
 | 
						||
    "\n",
 | 
						||
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
 | 
						||
    "        assert keys.shape[-1] == self.head_dim\n",
 | 
						||
    "\n",
 | 
						||
    "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
 | 
						||
    "        context_vec = (attn_weights @ values).transpose(1, 2)\n",
 | 
						||
    "\n",
 | 
						||
    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
 | 
						||
    "        context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n",
 | 
						||
    "        context_vec = self.out_proj(context_vec)  # optional projection\n",
 | 
						||
    "\n",
 | 
						||
    "        return context_vec"
 | 
						||
   ]
 | 
						||
  },
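  {
   "cell_type": "markdown",
   "id": "repeat-interleave-toy-note",
   "metadata": {},
   "source": [
    "- To make the `repeat_interleave` step above more concrete, here is a tiny toy example (the 2x2 tensor is made up purely for illustration) showing the difference between `repeat_interleave` and `repeat` along the kv-group dimension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "repeat-interleave-toy-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy example: two kv-groups, K1 and K2, each represented by one row\n",
    "kv_demo = torch.tensor([[1, 1], [2, 2]])\n",
    "\n",
    "print(kv_demo.repeat_interleave(2, dim=0))  # K1, K1, K2, K2\n",
    "print(kv_demo.repeat(2, 1))                 # K1, K2, K1, K2"
   ]
  },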
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "roAXSwJs9hR8",
 | 
						||
   "metadata": {
 | 
						||
    "id": "roAXSwJs9hR8"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- To illustrate the parameter savings, consider the following multi-head attention example from the GPT and Llama 2 code:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 9,
 | 
						||
   "id": "b4b8f085-349e-4674-a3f0-78fde0664fac",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "b4b8f085-349e-4674-a3f0-78fde0664fac",
 | 
						||
    "outputId": "9da09d72-43b1-45af-d46f-6928ea4af33a"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "W_key: torch.Size([4096, 4096])\n",
 | 
						||
      "W_value: torch.Size([4096, 4096])\n",
 | 
						||
      "W_query: torch.Size([4096, 4096])\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "# Settings\n",
 | 
						||
    "batch_size = 1\n",
 | 
						||
    "context_len = 3000\n",
 | 
						||
    "max_context_len = 8192\n",
 | 
						||
    "embed_dim = 4096\n",
 | 
						||
    "num_heads = 32\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "example_batch = torch.randn((batch_size, context_len, embed_dim))\n",
 | 
						||
    "\n",
 | 
						||
    "mha = MultiHeadAttention(\n",
 | 
						||
    "    d_in=embed_dim,\n",
 | 
						||
    "    d_out=embed_dim,\n",
 | 
						||
    "    context_length=max_context_len,\n",
 | 
						||
    "    num_heads=num_heads\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "mha(example_batch)\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"W_key:\", mha.W_key.weight.shape)\n",
 | 
						||
    "print(\"W_value:\", mha.W_value.weight.shape)\n",
 | 
						||
    "print(\"W_query:\", mha.W_query.weight.shape)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "IMQtFkcQ9sXC",
 | 
						||
   "metadata": {
 | 
						||
    "id": "IMQtFkcQ9sXC"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Now, if we use grouped-query attention instead, with 8 kv-groups (that's how many Llama 3 8B uses), we can see that the number of rows of the key and value matrices are reduced by a factor of 4 (because 32 attention heads divided by 8 kv-groups is 4)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 10,
 | 
						||
   "id": "15e65d3c-7b42-4ed3-bfee-bb09578657bb",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "15e65d3c-7b42-4ed3-bfee-bb09578657bb",
 | 
						||
    "outputId": "69709a78-2aaa-4597-8142-2f44eb59753f"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "W_key: torch.Size([1024, 4096])\n",
 | 
						||
      "W_value: torch.Size([1024, 4096])\n",
 | 
						||
      "W_query: torch.Size([4096, 4096])\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "gqa = GroupedQueryAttention(\n",
 | 
						||
    "    d_in=embed_dim,\n",
 | 
						||
    "    d_out=embed_dim,\n",
 | 
						||
    "    context_length=max_context_len,\n",
 | 
						||
    "    num_heads=num_heads,\n",
 | 
						||
    "    num_kv_groups=8,\n",
 | 
						||
    "    rope_base=llama_3_theta_base\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "gqa(example_batch)\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"W_key:\", gqa.W_key.weight.shape)\n",
 | 
						||
    "print(\"W_value:\", gqa.W_value.weight.shape)\n",
 | 
						||
    "print(\"W_query:\", gqa.W_query.weight.shape)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "1a5d4c88-c66a-483b-b4e2-419ff9fd60d5",
 | 
						||
   "metadata": {
 | 
						||
    "id": "1a5d4c88-c66a-483b-b4e2-419ff9fd60d5"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- As a side note, to make the GroupedQueryAttention equivalent to standard multi-head attention, you can set the number of query groups (`num_kv_groups`) equal to the number of heads (`num_heads`)\n",
 | 
						||
    "- Lastly, let's compare the number of parameters below:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 11,
 | 
						||
   "id": "58f713aa-ac00-4e2f-8247-94609aa01350",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "58f713aa-ac00-4e2f-8247-94609aa01350",
 | 
						||
    "outputId": "486dfd9c-9f3a-4b9e-f9a2-35fb43b9a5fb"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Total number of parameters:\n",
 | 
						||
      "MHA: 67,108,864\n",
 | 
						||
      "GQA: 41,943,040\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "print(\"Total number of parameters:\")\n",
 | 
						||
    "\n",
 | 
						||
    "mha_total_params = sum(p.numel() for p in mha.parameters())\n",
 | 
						||
    "print(f\"MHA: {mha_total_params:,}\")\n",
 | 
						||
    "\n",
 | 
						||
    "gqa_total_params = sum(p.numel() for p in gqa.parameters())\n",
 | 
						||
    "print(f\"GQA: {gqa_total_params:,}\")"
 | 
						||
   ]
 | 
						||
  },
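  {
   "cell_type": "markdown",
   "id": "gqa-mha-equivalence-note",
   "metadata": {},
   "source": [
    "- As an optional check of the side note above (added for illustration), setting `num_kv_groups` equal to `num_heads` should give a GroupedQueryAttention layer with exactly as many parameters as the multi-head attention layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gqa-mha-equivalence-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check: with one kv-group per head, GQA matches MHA's parameter count\n",
    "gqa_equiv = GroupedQueryAttention(\n",
    "    d_in=embed_dim,\n",
    "    d_out=embed_dim,\n",
    "    context_length=max_context_len,\n",
    "    num_heads=num_heads,\n",
    "    num_kv_groups=num_heads,  # one kv-group per head\n",
    "    rope_base=llama_3_theta_base\n",
    ")\n",
    "\n",
    "gqa_equiv_params = sum(p.numel() for p in gqa_equiv.parameters())\n",
    "print(gqa_equiv_params == mha_total_params)\n",
    "\n",
    "del gqa_equiv  # free up memory again"
   ]
  },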
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 12,
 | 
						||
   "id": "78b60dfd-6c0f-41f7-8f0c-8e57116f07f5",
 | 
						||
   "metadata": {
 | 
						||
    "id": "78b60dfd-6c0f-41f7-8f0c-8e57116f07f5"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# Free up memory:\n",
 | 
						||
    "del mha\n",
 | 
						||
    "del gqa"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "8fcd8802-2859-45a2-905a-f4fe96629dd9",
 | 
						||
   "metadata": {
 | 
						||
    "id": "8fcd8802-2859-45a2-905a-f4fe96629dd9"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 1.4 Update the TransformerBlock module"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "KABNccft_YnR",
 | 
						||
   "metadata": {
 | 
						||
    "id": "KABNccft_YnR"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Next, we update the `TransformerBlock`\n",
 | 
						||
    "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 13,
 | 
						||
   "id": "f9fa8eb4-7196-4dee-aec6-0dcbc70921c4",
 | 
						||
   "metadata": {
 | 
						||
    "id": "f9fa8eb4-7196-4dee-aec6-0dcbc70921c4"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "class TransformerBlock(nn.Module):\n",
 | 
						||
    "    def __init__(self, cfg):\n",
 | 
						||
    "        super().__init__()\n",
 | 
						||
    "        self.att =  GroupedQueryAttention(  # MultiHeadAttention(\n",
 | 
						||
    "            d_in=cfg[\"emb_dim\"],\n",
 | 
						||
    "            d_out=cfg[\"emb_dim\"],\n",
 | 
						||
    "            context_length=cfg[\"context_length\"],\n",
 | 
						||
    "            num_heads=cfg[\"n_heads\"],\n",
 | 
						||
    "            num_kv_groups=cfg[\"n_kv_groups\"],  # NEW\n",
 | 
						||
    "            rope_base=cfg[\"rope_base\"],        # NEW\n",
 | 
						||
    "            rope_config=cfg[\"rope_freq\"],      # NEW\n",
 | 
						||
    "            dtype=cfg[\"dtype\"]\n",
 | 
						||
    "        )\n",
 | 
						||
    "        self.ff = FeedForward(cfg)\n",
 | 
						||
    "        self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
 | 
						||
    "        self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
 | 
						||
    "\n",
 | 
						||
    "    def forward(self, x):\n",
 | 
						||
    "        # Shortcut connection for attention block\n",
 | 
						||
    "        shortcut = x\n",
 | 
						||
    "        x = self.norm1(x)\n",
 | 
						||
    "        x = self.att(x.to(torch.bfloat16))   # Shape [batch_size, num_tokens, emb_size]\n",
 | 
						||
    "        x = x + shortcut  # Add the original input back\n",
 | 
						||
    "\n",
 | 
						||
    "        # Shortcut connection for feed-forward block\n",
 | 
						||
    "        shortcut = x\n",
 | 
						||
    "        x = self.norm2(x)\n",
 | 
						||
    "        x = self.ff(x.to(torch.bfloat16))\n",
 | 
						||
    "        x = x + shortcut  # Add the original input back\n",
 | 
						||
    "\n",
 | 
						||
    "        return x"
 | 
						||
   ]
 | 
						||
  },
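  {
   "cell_type": "markdown",
   "id": "transformer-block-smoke-test-note",
   "metadata": {},
   "source": [
    "- As a quick optional smoke test (added for illustration; the configuration values below are made up and much smaller than the real Llama 3 settings, and we assume the imported `FeedForward` only reads `emb_dim`, `hidden_dim`, and `dtype` from the config, as in the Llama 2 notebook), we can pass a small random batch through the updated `TransformerBlock` and confirm that the output shape matches the input shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "transformer-block-smoke-test-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy configuration (made-up values for a quick shape check only)\n",
    "toy_cfg = {\n",
    "    \"emb_dim\": 64,\n",
    "    \"hidden_dim\": 128,\n",
    "    \"context_length\": 256,\n",
    "    \"n_heads\": 4,\n",
    "    \"n_kv_groups\": 2,\n",
    "    \"rope_base\": 500_000,\n",
    "    \"rope_freq\": None,\n",
    "    \"dtype\": torch.bfloat16\n",
    "}\n",
    "\n",
    "toy_block = TransformerBlock(toy_cfg)\n",
    "toy_x = torch.randn(1, 8, toy_cfg[\"emb_dim\"])\n",
    "print(toy_block(toy_x).shape)  # expected: torch.Size([1, 8, 64])"
   ]
  },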
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "fd921ab5-c48c-4c52-bf41-b847b3b822b9",
 | 
						||
   "metadata": {
 | 
						||
    "id": "fd921ab5-c48c-4c52-bf41-b847b3b822b9"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 1.5 Defining the model class"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "M_tLAq_r_llN",
 | 
						||
   "metadata": {
 | 
						||
    "id": "M_tLAq_r_llN"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- When setting up the model class, we fortunately don't have to do much; we just update the name to `Llama3Model`"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 14,
 | 
						||
   "id": "475755d6-01f7-4e6e-ad9a-cec6f031ebf6",
 | 
						||
   "metadata": {
 | 
						||
    "id": "475755d6-01f7-4e6e-ad9a-cec6f031ebf6"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# class Llama2Model(nn.Module):\n",
 | 
						||
    "class Llama3Model(nn.Module):\n",
 | 
						||
    "    def __init__(self, cfg):\n",
 | 
						||
    "        super().__init__()\n",
 | 
						||
    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
 | 
						||
    "\n",
 | 
						||
    "        self.trf_blocks = nn.Sequential(\n",
 | 
						||
    "            *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
 | 
						||
    "\n",
 | 
						||
    "        self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
 | 
						||
    "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
 | 
						||
    "\n",
 | 
						||
    "    def forward(self, in_idx):\n",
 | 
						||
    "        tok_embeds = self.tok_emb(in_idx)\n",
 | 
						||
    "        x = tok_embeds\n",
 | 
						||
    "        x = self.trf_blocks(x)\n",
 | 
						||
    "        x = self.final_norm(x)\n",
 | 
						||
    "        logits = self.out_head(x.to(torch.bfloat16))\n",
 | 
						||
    "        return logits"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60",
 | 
						||
   "metadata": {
 | 
						||
    "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 2. Initialize model"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "HoGGRAGykQTE",
 | 
						||
   "metadata": {
 | 
						||
    "id": "HoGGRAGykQTE"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Now we can define a Llama 3 config file (the Llama 2 config file is shown for comparison)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 15,
 | 
						||
   "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18",
 | 
						||
   "metadata": {
 | 
						||
    "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "LLAMA2_CONFIG_7B = {\n",
 | 
						||
    "    \"vocab_size\": 32_000,    # Vocabulary size\n",
 | 
						||
    "    \"context_length\": 4096,  # Context length\n",
 | 
						||
    "    \"emb_dim\": 4096,         # Embedding dimension\n",
 | 
						||
    "    \"n_heads\": 32,           # Number of attention heads\n",
 | 
						||
    "    \"n_layers\": 32,          # Number of layers\n",
 | 
						||
    "    \"hidden_dim\": 11_008,    # Size of the intermediate dimension in FeedForward\n",
 | 
						||
    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
 | 
						||
    "}"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 16,
 | 
						||
   "id": "2ad90f82-15c7-4806-b509-e45b56f57db5",
 | 
						||
   "metadata": {
 | 
						||
    "id": "2ad90f82-15c7-4806-b509-e45b56f57db5"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "LLAMA3_CONFIG_8B = {\n",
 | 
						||
    "    \"vocab_size\": 128_256,   # NEW: Larger vocabulary size\n",
 | 
						||
    "    \"context_length\": 8192,  # NEW: Larger context length\n",
 | 
						||
    "    \"emb_dim\": 4096,         # Embedding dimension\n",
 | 
						||
    "    \"n_heads\": 32,           # Number of attention heads\n",
 | 
						||
    "    \"n_layers\": 32,          # Number of layers\n",
 | 
						||
    "    \"hidden_dim\": 14_336,    # NEW: Larger size of the intermediate dimension in FeedForward\n",
 | 
						||
    "    \"n_kv_groups\": 8,        # NEW: Key-Value groups for grouped-query attention\n",
 | 
						||
    "    \"rope_base\": 500_000.0,  # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
 | 
						||
    "    \"rope_freq\": None,       # NEW: Additional configuration for adjusting the RoPE frequencies\n",
 | 
						||
    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
 | 
						||
    "}"
 | 
						||
   ]
 | 
						||
  },
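  {
   "cell_type": "markdown",
   "id": "config-param-delta-note",
   "metadata": {},
   "source": [
    "- As an optional back-of-the-envelope calculation (added for illustration; it assumes the SwiGLU `FeedForward` uses three `hidden_dim`-by-`emb_dim` weight matrices per layer, as in the Llama 2 notebook, and that the query and output projections are unchanged), we can estimate how these configuration changes shift the parameter count relative to Llama 2 7B before instantiating the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config-param-delta-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough parameter-count differences implied by the config changes\n",
    "est_emb_dim = LLAMA3_CONFIG_8B[\"emb_dim\"]\n",
    "est_n_layers = LLAMA3_CONFIG_8B[\"n_layers\"]\n",
    "est_head_dim = est_emb_dim // LLAMA3_CONFIG_8B[\"n_heads\"]\n",
    "\n",
    "# Larger vocabulary affects the token embedding and the output head\n",
    "vocab_extra = 2 * (LLAMA3_CONFIG_8B[\"vocab_size\"] - LLAMA2_CONFIG_7B[\"vocab_size\"]) * est_emb_dim\n",
    "\n",
    "# Larger intermediate size affects the three FeedForward weight matrices in each layer\n",
    "ffn_extra = 3 * (LLAMA3_CONFIG_8B[\"hidden_dim\"] - LLAMA2_CONFIG_7B[\"hidden_dim\"]) * est_emb_dim * est_n_layers\n",
    "\n",
    "# Grouped-query attention shrinks W_key and W_value in each layer\n",
    "gqa_savings = 2 * (est_emb_dim - LLAMA3_CONFIG_8B[\"n_kv_groups\"] * est_head_dim) * est_emb_dim * est_n_layers\n",
    "\n",
    "print(f\"Vocabulary-related increase : {vocab_extra:,}\")\n",
    "print(f\"FeedForward-related increase: {ffn_extra:,}\")\n",
    "print(f\"GQA-related savings         : {gqa_savings:,}\")\n",
    "print(f\"Net difference              : {vocab_extra + ffn_extra - gqa_savings:,}\")"
   ]
  },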
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "FAP7fiBzkaBz",
 | 
						||
   "metadata": {
 | 
						||
    "id": "FAP7fiBzkaBz"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Using these settings, we can now initialize a Llama 3 8B model\n",
 | 
						||
    "- Note that this requires ~34 GB of memory (for comparison, Llama 2 7B required ~26 GB of memory)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 17,
 | 
						||
   "id": "7004d785-ac9a-4df5-8760-6807fc604686",
 | 
						||
   "metadata": {
 | 
						||
    "id": "7004d785-ac9a-4df5-8760-6807fc604686"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "model = Llama3Model(LLAMA3_CONFIG_8B)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
 | 
						||
   "metadata": {},
 | 
						||
   "source": [
 | 
						||
    "- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": null,
 | 
						||
   "id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
 | 
						||
   "metadata": {},
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# Check buffers\n",
 | 
						||
    "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
 | 
						||
    "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
 | 
						||
    "print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "8056a521-91a6-440f-8473-591409c3177b",
 | 
						||
   "metadata": {},
 | 
						||
   "source": [
 | 
						||
    "- Let's now also compute the number of trainable parameters:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 18,
 | 
						||
   "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
 | 
						||
    "outputId": "0a8cd23b-d9fa-4c2d-ca63-3fc79bc4de0d"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Total number of parameters: 8,030,261,248\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "total_params = sum(p.numel() for p in model.parameters())\n",
 | 
						||
    "print(f\"Total number of parameters: {total_params:,}\")"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "Bx14NtzWk2wj",
 | 
						||
   "metadata": {
 | 
						||
    "id": "Bx14NtzWk2wj"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- As shown above, the model contains 8 billion parameters\n",
 | 
						||
    "- Additionally, we can calculate the memory requirements for this model using the code below:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 19,
 | 
						||
   "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
 | 
						||
    "outputId": "3425e9ce-d8c0-4b37-bded-a2c60b66a41a"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "float32 (PyTorch default): 68.08 GB\n",
 | 
						||
      "bfloat16: 34.04 GB\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "def model_memory_size(model, input_dtype=torch.float32):\n",
 | 
						||
    "    total_params = 0\n",
 | 
						||
    "    total_grads = 0\n",
 | 
						||
    "    for param in model.parameters():\n",
 | 
						||
    "        # Calculate total number of elements per parameter\n",
 | 
						||
    "        param_size = param.numel()\n",
 | 
						||
    "        total_params += param_size\n",
 | 
						||
    "        # Check if gradients are stored for this parameter\n",
 | 
						||
    "        if param.requires_grad:\n",
 | 
						||
    "            total_grads += param_size\n",
 | 
						||
    "\n",
 | 
						||
    "    # Calculate buffer size (non-parameters that require memory)\n",
 | 
						||
    "    total_buffers = sum(buf.numel() for buf in model.buffers())\n",
 | 
						||
    "\n",
 | 
						||
    "    # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
 | 
						||
    "    # We assume parameters and gradients are stored in the same type as input dtype\n",
 | 
						||
    "    element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
 | 
						||
    "    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
 | 
						||
    "\n",
 | 
						||
    "    # Convert bytes to gigabytes\n",
 | 
						||
    "    total_memory_gb = total_memory_bytes / (1024**3)\n",
 | 
						||
    "\n",
 | 
						||
    "    return total_memory_gb\n",
 | 
						||
    "\n",
 | 
						||
    "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
 | 
						||
    "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "zudd-5PulKFL",
 | 
						||
   "metadata": {
 | 
						||
    "id": "zudd-5PulKFL"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 20,
 | 
						||
   "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d",
 | 
						||
   "metadata": {
 | 
						||
    "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "if torch.cuda.is_available():\n",
 | 
						||
    "    device = torch.device(\"cuda\")\n",
 | 
						||
    "elif torch.backends.mps.is_available():\n",
 | 
						||
    "    device = torch.device(\"mps\")\n",
 | 
						||
    "else:\n",
 | 
						||
    "    device = torch.device(\"cpu\")\n",
 | 
						||
    "\n",
 | 
						||
    "model.to(device);"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34",
 | 
						||
   "metadata": {
 | 
						||
    "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 3. Load tokenizer"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005",
 | 
						||
   "metadata": {
 | 
						||
    "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- In this section, we are going to load the tokenizer for the model\n",
 | 
						||
    "- Llama 2 used Google's [SentencePiece](https://github.com/google/sentencepiece) tokenizer instead of OpenAI's BPE tokenizer based on the [Tiktoken](https://github.com/openai/tiktoken) library\n",
 | 
						||
    "- Llama 3, however, reverted back to using the BPE tokenizer from Tiktoken; specifically, it uses the GPT-4 tokenizer with an extended vocabulary\n",
 | 
						||
    "- You can find the original Tiktoken-adaptation by Meta AI [here](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py) in their official Llama 3 repository\n",
 | 
						||
    "- Below, I rewrote the tokenizer code to make it more readable and minimal for this notebook (but the behavior should be similar)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 21,
 | 
						||
   "id": "5f390cbf-8f92-46dc-afe3-d90b5affae10",
 | 
						||
   "metadata": {
 | 
						||
    "id": "5f390cbf-8f92-46dc-afe3-d90b5affae10"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "from pathlib import Path\n",
 | 
						||
    "\n",
 | 
						||
    "import tiktoken\n",
 | 
						||
    "from tiktoken.load import load_tiktoken_bpe\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "class Tokenizer:\n",
 | 
						||
    "    def __init__(self, model_path):\n",
 | 
						||
    "        assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
 | 
						||
    "        mergeable_ranks = load_tiktoken_bpe(model_path)\n",
 | 
						||
    "\n",
 | 
						||
    "        self.special_tokens = {\n",
 | 
						||
    "            \"<|begin_of_text|>\": 128000,\n",
 | 
						||
    "            \"<|end_of_text|>\": 128001,\n",
 | 
						||
    "            \"<|start_header_id|>\": 128006,\n",
 | 
						||
    "            \"<|end_header_id|>\": 128007,\n",
 | 
						||
    "            \"<|eot_id|>\": 128009,\n",
 | 
						||
    "        }\n",
 | 
						||
    "        self.special_tokens.update({\n",
 | 
						||
    "            f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n",
 | 
						||
    "        })\n",
 | 
						||
    "\n",
 | 
						||
    "        self.model = tiktoken.Encoding(\n",
 | 
						||
    "            name=Path(model_path).name,\n",
 | 
						||
    "            pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
 | 
						||
    "            mergeable_ranks=mergeable_ranks,\n",
 | 
						||
    "            special_tokens=self.special_tokens\n",
 | 
						||
    "        )\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
 | 
						||
    "        if bos:\n",
 | 
						||
    "            tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
 | 
						||
    "        else:\n",
 | 
						||
    "            tokens = []\n",
 | 
						||
    "\n",
 | 
						||
    "        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
 | 
						||
    "\n",
 | 
						||
    "        if eos:\n",
 | 
						||
    "            tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
 | 
						||
    "        return tokens\n",
 | 
						||
    "\n",
 | 
						||
    "    def decode(self, tokens):\n",
 | 
						||
    "        return self.model.decode(tokens)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "0a1509f8-8778-4fec-ba32-14d95c646167",
 | 
						||
   "metadata": {
 | 
						||
    "id": "0a1509f8-8778-4fec-ba32-14d95c646167"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Meta AI shared the original Llama 3 model weights and tokenizer vocabulary on the Hugging Face Hub\n",
 | 
						||
    "- We will first download the tokenizer vocabulary from the Hub and load it into the code above"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "KbnlzsbYmJU6",
 | 
						||
   "metadata": {
 | 
						||
    "id": "KbnlzsbYmJU6"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Please note that Meta AI requires that you accept the Llama 3 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) repository to accept the terms\n",
 | 
						||
    "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/settings.webp?1\" width=\"300px\">\n",
 | 
						||
    "\n",
 | 
						||
    "- Then, create and copy the access token so you can copy & paste it into the next code cell\n",
 | 
						||
    "\n",
 | 
						||
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/access-token.webp?1\" width=\"600px\">"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 22,
 | 
						||
   "id": "3357a230-b678-4691-a238-257ee4e80185",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "3357a230-b678-4691-a238-257ee4e80185",
 | 
						||
    "outputId": "a3652def-ea7f-46fb-f293-2a59affb71a0"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
 | 
						||
      "Token is valid (permission: read).\n",
 | 
						||
      "Your token has been saved to /root/.cache/huggingface/token\n",
 | 
						||
      "Login successful\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "from huggingface_hub import login\n",
 | 
						||
    "import json\n",
 | 
						||
    "\n",
 | 
						||
    "with open(\"config.json\", \"r\") as config_file:\n",
 | 
						||
    "    config = json.load(config_file)\n",
 | 
						||
    "    access_token = config[\"HF_ACCESS_TOKEN\"]\n",
 | 
						||
    "\n",
 | 
						||
    "login(token=access_token)"
 | 
						||
   ]
 | 
						||
  },
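  {
   "cell_type": "markdown",
   "id": "added-login-alternative-note",
   "metadata": {},
   "source": [
    "- As an added aside (not part of the original notebook): if you prefer not to keep the token in a `config.json` file, you can also pass it to `login()` directly; the commented-out sketch below shows this variant, where `\"hf_...\"` is a placeholder for your own READ token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-login-alternative-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional alternative (added sketch): log in by passing the token directly\n",
    "# (\"hf_...\" is a placeholder; replace it with your own READ access token)\n",
    "\n",
    "# from huggingface_hub import login\n",
    "# login(token=\"hf_...\")"
   ]
  },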
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "IxGh6ZYQo0VN",
 | 
						||
   "metadata": {
 | 
						||
    "id": "IxGh6ZYQo0VN"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- After login via the access token, which is necessary to verify that we accepted the Llama 3 licensing terms, we can now download the tokenizer vocabulary:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 23,
 | 
						||
   "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
 | 
						||
    "outputId": "c9836ba8-5176-4dd5-b618-6cc36fdbe1f0"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "from huggingface_hub import hf_hub_download\n",
 | 
						||
    "\n",
 | 
						||
    "tokenizer_file_path = hf_hub_download(\n",
 | 
						||
    "    repo_id=\"meta-llama/Meta-Llama-3-8B\",\n",
 | 
						||
    "    filename=\"original/tokenizer.model\",\n",
 | 
						||
    "    local_dir=\"Llama-3-8B\"\n",
 | 
						||
    ")"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "F8BH1Nk0AYCS",
 | 
						||
   "metadata": {
 | 
						||
    "id": "F8BH1Nk0AYCS"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Note that for using Llama 3 files, we may need the `blobfile` package, which is used when handling datasets or models stored in cloud storage solutions like Google Cloud Storage (GCS), Azure Blob Storage, or Amazon S3\n",
 | 
						||
    "- You can install this dependency by uncommenting and executing the `pip` command below\n"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 24,
 | 
						||
   "id": "5dm6Oz7uAytV",
 | 
						||
   "metadata": {
 | 
						||
    "id": "5dm6Oz7uAytV"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# pip install blobfile"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 25,
 | 
						||
   "id": "8b8c0ce6-a6fb-4b8a-8de2-ee7bb7646fd0",
 | 
						||
   "metadata": {
 | 
						||
    "id": "8b8c0ce6-a6fb-4b8a-8de2-ee7bb7646fd0"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "tokenizer = Tokenizer(tokenizer_file_path)"
 | 
						||
   ]
 | 
						||
  },
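  {
   "cell_type": "markdown",
   "id": "added-tokenizer-check-note",
   "metadata": {},
   "source": [
    "- As a quick optional sanity check (an added aside, not part of the original notebook), we can round-trip a short string through the tokenizer; with `bos=True` and `eos=True`, the `<|begin_of_text|>` and `<|end_of_text|>` special tokens are added around the encoded text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-tokenizer-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (added for illustration); the token IDs depend on the downloaded vocabulary\n",
    "sample_ids = tokenizer.encode(\"Hello world!\", bos=True, eos=True)\n",
    "print(sample_ids)\n",
    "print(tokenizer.decode(sample_ids))"
   ]
  },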
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "NVhmFeX3pT_M",
 | 
						||
   "metadata": {
 | 
						||
    "id": "NVhmFeX3pT_M"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- We can now use the `generate` function to have the Llama 3 model generate new text:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 26,
 | 
						||
   "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
 | 
						||
    "outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Output text:\n",
 | 
						||
      " Every effort_dead aeros Ingredients başında.extensionégor clangmissions güc như submodule.and report官方%,.Reader(\",\");\n",
 | 
						||
      "ामल ندار Parliamentary !!! HigginsDynamicZhgmt writeln Globalsletion 사진------\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "from previous_chapters import generate, text_to_token_ids, token_ids_to_text\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "torch.manual_seed(123)\n",
 | 
						||
    "\n",
 | 
						||
    "token_ids = generate(\n",
 | 
						||
    "    model=model,\n",
 | 
						||
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
 | 
						||
    "    max_new_tokens=30,\n",
 | 
						||
    "    context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n",
 | 
						||
    "    top_k=1,\n",
 | 
						||
    "    temperature=0.\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "93WTtAA5paYV",
 | 
						||
   "metadata": {
 | 
						||
    "id": "93WTtAA5paYV"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Of course, as we can see above, the text is nonsensical since we haven't trained the Llama 3 model yet\n",
 | 
						||
    "- In the next section, instead of training it ourselves, which would cost tens to hundreds of thousands of dollars, we load the pretrained weights from Meta AI"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "f63cc248-1d27-4eb6-aa50-173b436652f8",
 | 
						||
   "metadata": {
 | 
						||
    "id": "f63cc248-1d27-4eb6-aa50-173b436652f8"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 4. Load pretrained weights"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "aKeN7rUfqZMI",
 | 
						||
   "metadata": {
 | 
						||
    "id": "aKeN7rUfqZMI"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- We are loading the [\"meta-llama/Meta-Llama-3-8B\"](https://huggingface.co/meta-llama/Meta-Llama-3-8B) base model below, which is a simple text completion model before finetuning\n",
 | 
						||
    "- Alternatively, you can load the instruction-finetuned and aligned [\"meta-llama/Meta-Llama-3-8B-Instruct\"](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model by modifying the string in the next code cell accordingly\n",
 | 
						||
    "- Combined, the weight files are about 16 GB large"
 | 
						||
   ]
 | 
						||
  },
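  {
   "cell_type": "markdown",
   "id": "added-size-estimate-note",
   "metadata": {},
   "source": [
    "- As a rough added sanity check (not part of the original notebook), this matches the parameter count: storing roughly 8 billion parameters in bfloat16, at 2 bytes per parameter, comes out to about 16 GB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-size-estimate-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough size estimate (added aside): number of parameters x 2 bytes (bfloat16)\n",
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Approximate weight size: {total_params * 2 / 1e9:.1f} GB\")"
   ]
  },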
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 27,
 | 
						||
   "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/",
 | 
						||
     "height": 145,
 | 
						||
     "referenced_widgets": [
 | 
						||
      "f3788acce34f4956b0727b58d0cf38c6",
 | 
						||
      "6022a9426683420690d9b41a0ca4f870",
 | 
						||
      "e9aba3d53b4d45c485a7aad649c7b465",
 | 
						||
      "f1a12d7929db4309b9881853135359fc",
 | 
						||
      "58c9dec75a3346b1b787f88dd510d254",
 | 
						||
      "9492edc02dee456f840325d913fa4e4f",
 | 
						||
      "66dc94b23556499f985f8accbb1f89cb",
 | 
						||
      "7c6658cfff1a4d27af3de148184f77d9",
 | 
						||
      "7266a729edfb4a44b5b1c67dc79be146",
 | 
						||
      "76dbab4873f342019c5d7624ae2c9775",
 | 
						||
      "3cea4b431147441a8d9bd872811d5974",
 | 
						||
      "8ae98969541849efa356cf912ac39b1e",
 | 
						||
      "f9373112649945e3b446c3e1ec274dc1",
 | 
						||
      "d49791082a304ade95c185c79fae1f41",
 | 
						||
      "616e383bb3d442bcb6edb2721a8180b6",
 | 
						||
      "87f474861e54432e9d533e0a89bb77da",
 | 
						||
      "e805bb6dfee34dab8870f4618d8bffdb",
 | 
						||
      "be3e9bf271f04eb0b119659e1af3a0ea",
 | 
						||
      "00148825ce0248b7a23eb28e3eca6749",
 | 
						||
      "f1a9b0c2431640298a6c1b258298b12d",
 | 
						||
      "8ba9f009e92a46fcbcbb401dc444f12e",
 | 
						||
      "d74186bb74d142dfb683fa347b6990f7",
 | 
						||
      "9bb60a5a3710463ebe3a17f8d2a446be",
 | 
						||
      "0a08fb81165748748ccb080e6df0600f",
 | 
						||
      "603690f543114a7fb6aebd433c80bdc3",
 | 
						||
      "773b802daed942f5a11f3eab3b83be08",
 | 
						||
      "7989003a613e45f780d3f800e121543a",
 | 
						||
      "9d49589118f5432cac49650251046429",
 | 
						||
      "f114549fe8ce49638a791ca2fecb2d89",
 | 
						||
      "0aa155b794a8426aa265f4a7670f43ad",
 | 
						||
      "a06fbde549cc47fdaddfbdb82d35d823",
 | 
						||
      "172c0c6955e1428b999dcb2d133704cd",
 | 
						||
      "1bf7108774c34016a2193e2cd7639b7d",
 | 
						||
      "ed28e180d94a4b7aa548581612e31232",
 | 
						||
      "ff4338faded5494da1ccb660e1c441ed",
 | 
						||
      "b46a08cf4929422eb0f76d8d9af11249",
 | 
						||
      "f049eb4a50f54c34912ca959d2eaf353",
 | 
						||
      "80dfd3e80ceb444a83ec1fd65f9af80e",
 | 
						||
      "519147a10b984befbd0f255f78c1f66a",
 | 
						||
      "562e82438dbe41b793ff488b8447c5bf",
 | 
						||
      "1da83719e47c4196b06f3aa32056b560",
 | 
						||
      "c4a2c88326d14fbca87cfde073755a2e",
 | 
						||
      "f0ab5a46cbb0444c88ed137d8a95002b",
 | 
						||
      "f8f28ac0e149428f9fef42373c6a87d0"
 | 
						||
     ]
 | 
						||
    },
 | 
						||
    "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
 | 
						||
    "outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "245443330e4d40c887a5649cc1663e98",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "from safetensors.torch import load_file\n",
 | 
						||
    "\n",
 | 
						||
    "combined_weights = {}\n",
 | 
						||
    "\n",
 | 
						||
    "for i in range(1, 5):\n",
 | 
						||
    "    weights_file = hf_hub_download(\n",
 | 
						||
    "        repo_id=\"meta-llama/Meta-Llama-3-8B\",\n",
 | 
						||
    "        filename=f\"model-0000{i}-of-00004.safetensors\",\n",
 | 
						||
    "        local_dir=\"Llama-3-8B\"\n",
 | 
						||
    "    )\n",
 | 
						||
    "    current_weights = load_file(weights_file)\n",
 | 
						||
    "    combined_weights.update(current_weights)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "-15SJ7btq2zE",
 | 
						||
   "metadata": {
 | 
						||
    "id": "-15SJ7btq2zE"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- The `weights` contains the following tensors (only the first 15 are shown for simplicity):"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 28,
 | 
						||
   "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
 | 
						||
    "outputId": "2fbc2786-677f-4fea-9472-5fb8542ff14b"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "text/plain": [
 | 
						||
       "['model.embed_tokens.weight',\n",
 | 
						||
       " 'model.layers.0.input_layernorm.weight',\n",
 | 
						||
       " 'model.layers.0.mlp.down_proj.weight',\n",
 | 
						||
       " 'model.layers.0.mlp.gate_proj.weight',\n",
 | 
						||
       " 'model.layers.0.mlp.up_proj.weight',\n",
 | 
						||
       " 'model.layers.0.post_attention_layernorm.weight',\n",
 | 
						||
       " 'model.layers.0.self_attn.k_proj.weight',\n",
 | 
						||
       " 'model.layers.0.self_attn.o_proj.weight',\n",
 | 
						||
       " 'model.layers.0.self_attn.q_proj.weight',\n",
 | 
						||
       " 'model.layers.0.self_attn.v_proj.weight',\n",
 | 
						||
       " 'model.layers.1.input_layernorm.weight',\n",
 | 
						||
       " 'model.layers.1.mlp.down_proj.weight',\n",
 | 
						||
       " 'model.layers.1.mlp.gate_proj.weight',\n",
 | 
						||
       " 'model.layers.1.mlp.up_proj.weight',\n",
 | 
						||
       " 'model.layers.1.post_attention_layernorm.weight']"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "execution_count": 28,
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "execute_result"
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "list(combined_weights.keys())[:15]"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "UeeSpnunrDFB",
 | 
						||
   "metadata": {
 | 
						||
    "id": "UeeSpnunrDFB"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- The following function, modeled after the `load_weights_into_gpt` function in [chapter 5](../01_main-chapter-code/ch05.ipynb), loads the pretrained weights into our Llama 3 model:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 29,
 | 
						||
   "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
 | 
						||
   "metadata": {
 | 
						||
    "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "def assign(left, right, tensor_name=\"unknown\"):\n",
 | 
						||
    "    if left.shape != right.shape:\n",
 | 
						||
    "        raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
 | 
						||
    "\n",
 | 
						||
    "    if isinstance(right, torch.Tensor):\n",
 | 
						||
    "        return torch.nn.Parameter(right.clone().detach())\n",
 | 
						||
    "    else:\n",
 | 
						||
    "        return torch.nn.Parameter(torch.tensor(right))\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "def load_weights_into_llama(model, param_config, params):\n",
 | 
						||
    "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
 | 
						||
    "\n",
 | 
						||
    "    for l in range(param_config[\"n_layers\"]):\n",
 | 
						||
    "\n",
 | 
						||
    "        # Load attention weights\n",
 | 
						||
    "        model.trf_blocks[l].att.W_query.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].att.W_query.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "        model.trf_blocks[l].att.W_key.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].att.W_key.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "        model.trf_blocks[l].att.W_value.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].att.W_value.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "        model.trf_blocks[l].att.out_proj.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].att.out_proj.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "        model.trf_blocks[l].norm1.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].norm1.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.input_layernorm.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "\n",
 | 
						||
    "        # Load FeedForward weights\n",
 | 
						||
    "        model.trf_blocks[l].ff.fc1.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].ff.fc1.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "        model.trf_blocks[l].ff.fc2.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].ff.fc2.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.mlp.up_proj.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "        model.trf_blocks[l].ff.fc3.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].ff.fc3.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.mlp.down_proj.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "        model.trf_blocks[l].norm2.weight = assign(\n",
 | 
						||
    "            model.trf_blocks[l].norm2.weight,\n",
 | 
						||
    "            params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
 | 
						||
    "            f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
 | 
						||
    "        )\n",
 | 
						||
    "\n",
 | 
						||
    "    # Load output layer weights\n",
 | 
						||
    "    model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
 | 
						||
    "\n",
 | 
						||
    "    if \"lm_head.weight\" in params.keys():\n",
 | 
						||
    "        model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
 | 
						||
    "    else:\n",
 | 
						||
    "        model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
 | 
						||
    "        print(\"Model uses weight tying.\")\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "load_weights_into_llama(model, LLAMA3_CONFIG_8B, combined_weights)\n",
 | 
						||
    "model.to(device);\n",
 | 
						||
    "del combined_weights  # free up memory"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "TDuv_Us2rNvk",
 | 
						||
   "metadata": {
 | 
						||
    "id": "TDuv_Us2rNvk"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Next, we are ready to use the model for text generation"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 30,
 | 
						||
   "id": "240987e8-a023-462e-9376-9edfb27559ec",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "240987e8-a023-462e-9376-9edfb27559ec",
 | 
						||
    "outputId": "6dab0e56-40a8-45db-a096-ab2b9ee97a69"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Output text:\n",
 | 
						||
      " Every effort has been made to trace copyright holders and to obtain their permission for the use of copyright material. The publisher apologizes for any\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "torch.manual_seed(123)\n",
 | 
						||
    "\n",
 | 
						||
    "token_ids = generate(\n",
 | 
						||
    "    model=model,\n",
 | 
						||
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
 | 
						||
    "    max_new_tokens=25,\n",
 | 
						||
    "    context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n",
 | 
						||
    "    top_k=1,\n",
 | 
						||
    "    temperature=0.\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "1203041e-4794-4157-a978-3ce80909da44",
 | 
						||
   "metadata": {
 | 
						||
    "id": "1203041e-4794-4157-a978-3ce80909da44"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "## 5. Using the instruction-finetuned model"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "akyo7WNyF_YL",
 | 
						||
   "metadata": {
 | 
						||
    "id": "akyo7WNyF_YL"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Above, we used the pretrained base model; if you want to use a model capable of following instructions, use the `\"meta-llama/Llama-3-8B-Instruct\"` model instead, as shown below"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 31,
 | 
						||
   "id": "hdA-xjjdS26J",
 | 
						||
   "metadata": {
 | 
						||
    "id": "hdA-xjjdS26J"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# to free up memory\n",
 | 
						||
    "\n",
 | 
						||
    "import gc\n",
 | 
						||
    "\n",
 | 
						||
    "del model\n",
 | 
						||
    "\n",
 | 
						||
    "gc.collect()  # Run Python garbage collector\n",
 | 
						||
    "\n",
 | 
						||
    "if torch.cuda.is_available():\n",
 | 
						||
    "    torch.cuda.empty_cache()"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 32,
 | 
						||
   "id": "nbvAV7vaz6yc",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/",
 | 
						||
     "height": 145,
 | 
						||
     "referenced_widgets": [
 | 
						||
      "409470784b6346a981920350de4f6f28",
 | 
						||
      "9ba6a11ffd194bf9a0900f52a7ed4d4f",
 | 
						||
      "acae8bbbb4a84ed49be72fecd11fb052",
 | 
						||
      "e8a4b441281b4038bb0204d093411f68",
 | 
						||
      "bdf8b693821344fc97918e6cbc31c8bf",
 | 
						||
      "97e8877869cd4be68ff38ce745be5045",
 | 
						||
      "cc3da88e93c4499993b7bbb7d3064326",
 | 
						||
      "0d51fdc2c416474da04079db6579890f",
 | 
						||
      "c4598300a77b4667b1117f9499f5ccb7",
 | 
						||
      "77606cd2fe1b4d33a91ede944bb1dec0",
 | 
						||
      "f1ba439c26d64c90af2f162c74348405",
 | 
						||
      "d598f094c3ce4daeab19fac8094cba7e",
 | 
						||
      "0afc2d23514b45c9890b5d2ee4e6fa0b",
 | 
						||
      "3da5d38bf3314d3eaa7cedebae41c076",
 | 
						||
      "55e6b727a4594078beb3853cc1891308",
 | 
						||
      "f17fa78263414ef8b414c7bf3ac03192",
 | 
						||
      "e8b187b40ec14db3af17a380830a35bf",
 | 
						||
      "e94ca32eaa9f4714a3b05a5fdf24d02b",
 | 
						||
      "3edd464991204b8690eae02f10b4cc00",
 | 
						||
      "ac1e34f4bd6c420bb6cc2fdde5f3ed4d",
 | 
						||
      "1cd5e07cad35450182004952de32c8e7",
 | 
						||
      "a63351a6715643378491ba831b3fb05d",
 | 
						||
      "98b4680141ee423bb5e43c47613d8440",
 | 
						||
      "b02ffefca3f34252914e76f4a8a467dc",
 | 
						||
      "31d27bf34a74432f8e0dbfe9ecb76130",
 | 
						||
      "a3137f3669b54e84be91010c9654d985",
 | 
						||
      "5a2886564d3f40ceaa30b743dbe81f45",
 | 
						||
      "15ea8fcfe097471e8fc9502a162f5904",
 | 
						||
      "c779e80c50ba4434bfa1d326c5cc9b0f",
 | 
						||
      "eb94612785e64552aea8674dc8647a93",
 | 
						||
      "279cffe683fe4e7383062162e07ed9ed",
 | 
						||
      "6176990205cc499f8995c71fc6b9d4df",
 | 
						||
      "66c23ae98bcc45f18fc5c91e0e73c3e4",
 | 
						||
      "05b502e1e3a9436297dafbb1ce7af722",
 | 
						||
      "25977b0d89084703ad787fe9208b5aad",
 | 
						||
      "71a84ee5fc964ec89ff2832c84735cc2",
 | 
						||
      "6aed783eccb942318e6384e253ad4924",
 | 
						||
      "84c34bfecda64391a609e19f131d51d4",
 | 
						||
      "20ecac7c646b45938ed393cb20977c37",
 | 
						||
      "ebe04aeaaac042aaaa0885992e45793d",
 | 
						||
      "ca81071ab07446df96795a482ce0c630",
 | 
						||
      "e0550cab24c7492787af40dc4b8576bf",
 | 
						||
      "7015bf6f85954036aaf8cc4f1c44ea0f",
 | 
						||
      "2a2ba3d065634484a932b8d3c212af56"
 | 
						||
     ]
 | 
						||
    },
 | 
						||
    "id": "nbvAV7vaz6yc",
 | 
						||
    "outputId": "9e1badc9-a6c4-48b7-9125-e0810655528b"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "f7df6bbf8e63448c8a6cb5d2f6208403",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00001-of-00004.safetensors:  36%|###6      | 1.81G/4.98G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    },
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "4772f31a1c5b4c168c9aabe7a1d2bacc",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    },
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "ad49eeb9e1204ea2bd2e371df8ccdea2",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    },
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "951b9e81613a40a2a503f61e69677f0a",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "combined_weights = {}\n",
 | 
						||
    "\n",
 | 
						||
    "for i in range(1, 5):\n",
 | 
						||
    "    weights_file = hf_hub_download(\n",
 | 
						||
    "        repo_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
 | 
						||
    "        filename=f\"model-0000{i}-of-00004.safetensors\",\n",
 | 
						||
    "        local_dir=\"Llama-3-8B-Instruct\"\n",
 | 
						||
    "    )\n",
 | 
						||
    "    current_weights = load_file(weights_file)\n",
 | 
						||
    "    combined_weights.update(current_weights)\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "model = Llama3Model(LLAMA3_CONFIG_8B)\n",
 | 
						||
    "load_weights_into_llama(model, LLAMA3_CONFIG_8B, combined_weights)\n",
 | 
						||
    "model.to(device)\n",
 | 
						||
    "del combined_weights  # free up memory"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "VlH7qYVdDKQr",
 | 
						||
   "metadata": {
 | 
						||
    "id": "VlH7qYVdDKQr"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Note that the Llama 3 model should ideally be used with the correct prompt template that was used during finetuning (as discussed in chapter 7)\n",
 | 
						||
    "- Below is a wrapper class around the tokenizer based on Meta AI's Llama 3-specific [ChatFormat code](https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/tokenizer.py#L202) that constructs the prompt template"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 33,
 | 
						||
   "id": "4be5b481-1110-46e8-a931-3988d890cf8c",
 | 
						||
   "metadata": {
 | 
						||
    "id": "4be5b481-1110-46e8-a931-3988d890cf8c"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "class ChatFormat:\n",
 | 
						||
    "    def __init__(self, tokenizer):\n",
 | 
						||
    "        self.tokenizer = tokenizer\n",
 | 
						||
    "\n",
 | 
						||
    "    def encode_header(self, message):\n",
 | 
						||
    "        tokens = []\n",
 | 
						||
    "        tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
 | 
						||
    "        tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
 | 
						||
    "        tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
 | 
						||
    "        tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
 | 
						||
    "        return tokens\n",
 | 
						||
    "\n",
 | 
						||
    "    def encode(self, text):\n",
 | 
						||
    "        message = {\n",
 | 
						||
    "            \"role\": \"user\",\n",
 | 
						||
    "            \"content\": text\n",
 | 
						||
    "        }\n",
 | 
						||
    "\n",
 | 
						||
    "        tokens = self.encode_header(message)\n",
 | 
						||
    "        tokens.extend(\n",
 | 
						||
    "            self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
 | 
						||
    "        )\n",
 | 
						||
    "        tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
 | 
						||
    "        return tokens\n",
 | 
						||
    "\n",
 | 
						||
    "    def decode(self, token_ids):\n",
 | 
						||
    "        return self.tokenizer.decode(token_ids)\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "chat_tokenizer = ChatFormat(tokenizer)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "M-dkSNvwDttN",
 | 
						||
   "metadata": {
 | 
						||
    "id": "M-dkSNvwDttN"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- The usage is as follows:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 34,
 | 
						||
   "id": "nwBrTGTsUNhn",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "nwBrTGTsUNhn",
 | 
						||
    "outputId": "72a495b4-b872-429a-88ef-49a9b4577f0f"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "[128006, 882, 128007, 271, 9906, 4435, 0, 128009]\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "token_ids = chat_tokenizer.encode(\"Hello World!\")\n",
 | 
						||
    "print(token_ids)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 35,
 | 
						||
   "id": "0fpmpVgYVTRZ",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/",
 | 
						||
     "height": 36
 | 
						||
    },
 | 
						||
    "id": "0fpmpVgYVTRZ",
 | 
						||
    "outputId": "bb3e819a-112a-466c-ac51-5d14a9c3475b"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "text/plain": [
 | 
						||
       "'<|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|>'"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "execution_count": 35,
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "execute_result"
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "tokenizer.decode(token_ids)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "Wo-aUGeKDvqq",
 | 
						||
   "metadata": {
 | 
						||
    "id": "Wo-aUGeKDvqq"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Let's now see the Llama 3 instruction model in action:"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 36,
 | 
						||
   "id": "ozGOBu6XOkEW",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "ozGOBu6XOkEW",
 | 
						||
    "outputId": "4f689c70-bed9-46f3-a52a-aea47b641283"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Output text:\n",
 | 
						||
      " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:\n",
 | 
						||
      "\n",
 | 
						||
      "1. Grass: Llamas love to graze on grass, especially in the spring and summer months.\n",
 | 
						||
      "2. Hay: Hay is a staple in a llama's diet. They like to eat timothy hay, alfalfa hay, and other types of hay.\n",
 | 
						||
      "3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10-15% of a llama's diet.\n",
 | 
						||
      "4. Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "torch.manual_seed(123)\n",
 | 
						||
    "\n",
 | 
						||
    "token_ids = generate(\n",
 | 
						||
    "    model=model,\n",
 | 
						||
    "    idx=text_to_token_ids(\"What do llamas eat?\", chat_tokenizer).to(device),\n",
 | 
						||
    "    max_new_tokens=150,\n",
 | 
						||
    "    context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n",
 | 
						||
    "    top_k=1,\n",
 | 
						||
    "    temperature=0.\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "output_text = token_ids_to_text(token_ids, tokenizer)\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n",
 | 
						||
    "    # Find the index of the first occurrence of \"<|end_header_id|>\"\n",
 | 
						||
    "    index = text.find(header_end)\n",
 | 
						||
    "\n",
 | 
						||
    "    if index != -1:\n",
 | 
						||
    "        # Return the substring starting after \"<|end_header_id|>\"\n",
 | 
						||
    "        return text[index + len(header_end):].strip()  # Strip removes leading/trailing whitespace\n",
 | 
						||
    "    else:\n",
 | 
						||
    "        # If the token is not found, return the original text\n",
 | 
						||
    "        return text\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"Output text:\\n\", clean_text(output_text))"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "2r5JKrO-ZOHK",
 | 
						||
   "metadata": {
 | 
						||
    "id": "2r5JKrO-ZOHK"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "# Llama 3.1 8B"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "QiQxX0XnP_iC",
 | 
						||
   "metadata": {
 | 
						||
    "id": "QiQxX0XnP_iC"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- A few months after the initial Llama 3 release, Meta AI followed up with their Llama 3.1 suite of models (see the official [Introducing Llama 3.1: Our most capable models to date](https://ai.meta.com/blog/meta-llama-3-1/) announcement blog post for details)\n",
 | 
						||
    "- Conveniently, we can reuse our previous Llama 3 code from above to implement Llama 3.1 8B\n",
 | 
						||
    "\n",
 | 
						||
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/llama3-to-llama31.webp\" width=\"700px\">\n",
 | 
						||
    "\n",
 | 
						||
    "- The architecture is identical, with the only change being a rescaling of the RoPE frequencies as indicated in the configuration file below\n",
 | 
						||
    "\n"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 37,
 | 
						||
   "id": "X5Fg8XUHMv4M",
 | 
						||
   "metadata": {
 | 
						||
    "id": "X5Fg8XUHMv4M"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "LLAMA3_CONFIG_8B = {\n",
 | 
						||
    "    \"vocab_size\": 128_256,   # Vocabulary size\n",
 | 
						||
    "    \"context_length\": 8192,  # Context length\n",
 | 
						||
    "    \"emb_dim\": 4096,         # Embedding dimension\n",
 | 
						||
    "    \"n_heads\": 32,           # Number of attention heads\n",
 | 
						||
    "    \"n_layers\": 32,          # Number of layers\n",
 | 
						||
    "    \"hidden_dim\": 14_336,    # Size of the intermediate dimension in FeedForward\n",
 | 
						||
    "    \"n_kv_groups\": 8,        # Key-Value groups for grouped-query attention\n",
 | 
						||
    "    \"rope_base\": 500_000.0,  # The base in RoPE's \"theta\"\n",
 | 
						||
    "    \"rope_freq\": None,       # Additional configuration for adjusting the RoPE frequencies\n",
 | 
						||
    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
 | 
						||
    "}\n",
 | 
						||
    "\n",
 | 
						||
    "LLAMA31_CONFIG_8B = {\n",
 | 
						||
    "    \"vocab_size\": 128_256,      # Vocabulary size\n",
 | 
						||
    "    \"context_length\": 131_072,  # NEW: Larger supported context length\n",
 | 
						||
    "    \"emb_dim\": 4096,            # Embedding dimension\n",
 | 
						||
    "    \"n_heads\": 32,              # Number of attention heads\n",
 | 
						||
    "    \"n_layers\": 32,             # Number of layers\n",
 | 
						||
    "    \"hidden_dim\": 14_336,       # Size of the intermediate dimension in FeedForward\n",
 | 
						||
    "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
 | 
						||
    "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
 | 
						||
    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
 | 
						||
    "    \"rope_freq\": {              # NEW: RoPE frequency scaling\n",
 | 
						||
    "        \"factor\": 8.0,\n",
 | 
						||
    "        \"low_freq_factor\": 1.0,\n",
 | 
						||
    "        \"high_freq_factor\": 4.0,\n",
 | 
						||
    "        \"original_context_length\": 8192,\n",
 | 
						||
    "    }\n",
 | 
						||
    "}"
 | 
						||
   ]
 | 
						||
  },
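  {
   "cell_type": "markdown",
   "id": "added-rope-scaling-note",
   "metadata": {},
   "source": [
    "- The `rope_freq` entries control how the RoPE inverse frequencies are rescaled for the longer context: the inverse frequencies of long-wavelength (low-frequency) components are divided by `factor`, short-wavelength components are left unchanged, and the band in between is smoothly interpolated\n",
    "- The cell below is a small standalone sketch of this idea (an added aside; the helper name `rescale_rope_inv_freq` is made up for illustration and is not part of the notebook's model code, which applies the scaling when precomputing the RoPE parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-rope-scaling-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standalone sketch (added for illustration) of Llama 3.1-style RoPE frequency scaling\n",
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def rescale_rope_inv_freq(inv_freq, freq_cfg):\n",
    "    factor = freq_cfg[\"factor\"]\n",
    "    low_freq_factor = freq_cfg[\"low_freq_factor\"]\n",
    "    high_freq_factor = freq_cfg[\"high_freq_factor\"]\n",
    "    old_ctx_len = freq_cfg[\"original_context_length\"]\n",
    "\n",
    "    low_freq_wavelen = old_ctx_len / low_freq_factor\n",
    "    high_freq_wavelen = old_ctx_len / high_freq_factor\n",
    "    wavelen = 2 * math.pi / inv_freq\n",
    "\n",
    "    # Long-wavelength (low-frequency) components get their inverse frequencies divided by `factor`\n",
    "    new_inv_freq = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)\n",
    "\n",
    "    # Components between the two wavelength thresholds are smoothly interpolated\n",
    "    smooth = (old_ctx_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)\n",
    "    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq\n",
    "    is_medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n",
    "    return torch.where(is_medium, smoothed, new_inv_freq)\n",
    "\n",
    "\n",
    "# Example with the head dimension of Llama 3.1 8B (emb_dim / n_heads = 128)\n",
    "head_dim = LLAMA31_CONFIG_8B[\"emb_dim\"] // LLAMA31_CONFIG_8B[\"n_heads\"]\n",
    "inv_freq = 1.0 / (LLAMA31_CONFIG_8B[\"rope_base\"] ** (torch.arange(0, head_dim, 2).float() / head_dim))\n",
    "print(rescale_rope_inv_freq(inv_freq, LLAMA31_CONFIG_8B[\"rope_freq\"])[:5])"
   ]
  },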
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
 | 
						||
   "metadata": {},
 | 
						||
   "source": [
 | 
						||
    "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 10,
 | 
						||
   "id": "a55a8769-1a03-4265-8fd0-15f1c423da53",
 | 
						||
   "metadata": {
 | 
						||
    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "New RoPE theta: 31250.0\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "old_context_length = LLAMA31_CONFIG_8B[\"context_length\"]\n",
 | 
						||
    "LLAMA31_CONFIG_8B[\"context_length\"] = 8192\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "def rescale_theta(theta_old, context_length_old, context_length_new):\n",
 | 
						||
    "    scaling_factor = context_length_new / context_length_old\n",
 | 
						||
    "    theta_new = theta_old * scaling_factor\n",
 | 
						||
    "    return theta_new\n",
 | 
						||
    "\n",
 | 
						||
    "LLAMA31_CONFIG_8B[\"rope_base\"] = rescale_theta(\n",
 | 
						||
    "    LLAMA31_CONFIG_8B[\"rope_base\"],\n",
 | 
						||
    "    old_context_length,\n",
 | 
						||
    "    LLAMA31_CONFIG_8B[\"context_length\"]\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"New RoPE theta:\", LLAMA31_CONFIG_8B[\"rope_base\"])"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "xa3bpMDtTdBs",
 | 
						||
   "metadata": {
 | 
						||
    "id": "xa3bpMDtTdBs"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- As we've seen in the code earlier, the RoPE method uses sinusoidal functions (sine and cosine) to embed positional information directly into the attention mechanism\n",
 | 
						||
    "- In Llama 3.1, via the additional configuration, we introduce additional adjustments to the inverse frequency calculations\n",
 | 
						||
    "- These adjustments influence how different frequency components contribute to the positional embeddings (a detailed explanation is a topic for another time)\n",
 | 
						||
    "- Let's try out the Llama 3.1 model in practice; first, we clear out the old model to free up some GPU memory"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 38,
 | 
						||
   "id": "7dUtYnNUOqhL",
 | 
						||
   "metadata": {
 | 
						||
    "id": "7dUtYnNUOqhL"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# free up memory\n",
 | 
						||
    "del model\n",
 | 
						||
    "\n",
 | 
						||
    "gc.collect()  # Run Python garbage collector\n",
 | 
						||
    "\n",
 | 
						||
    "if torch.cuda.is_available():\n",
 | 
						||
    "    torch.cuda.empty_cache()"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "DbbVsll6TYWR",
 | 
						||
   "metadata": {
 | 
						||
    "id": "DbbVsll6TYWR"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Next, we download the tokenizer\n",
 | 
						||
    "- Note that since the Llama 3.1 family is distinct from the Llama 3 family, you'd have to go to the [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) repository and acknowledge the license terms for your Hugging Face access token to work for the download\n",
 | 
						||
    "- Tip: For simplicity, we only load the base model below, but there's also an instruction-finetuned version you can use by replacing `\"meta-llama/Llama-3.1-8B\"` with `\"meta-llama/Llama-3.1-8B-Instruct\"`"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 39,
 | 
						||
   "id": "8xDk4chtPNU4",
 | 
						||
   "metadata": {
 | 
						||
    "id": "8xDk4chtPNU4"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "tokenizer_file_path = hf_hub_download(\n",
 | 
						||
    "    repo_id=\"meta-llama/Llama-3.1-8B\",\n",
 | 
						||
    "    filename=\"original/tokenizer.model\",\n",
 | 
						||
    "    local_dir=\"Llama-3.1-8B\"\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "tokenizer = Tokenizer(tokenizer_file_path)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 40,
 | 
						||
   "id": "a7l21VE4Otcs",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "a7l21VE4Otcs",
 | 
						||
    "outputId": "3dd5cfba-bf3f-44d2-9be1-7cd42bfe4ba9"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Total number of parameters: 8,030,261,248\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "model = Llama3Model(LLAMA31_CONFIG_8B)\n",
 | 
						||
    "\n",
 | 
						||
    "total_params = sum(p.numel() for p in model.parameters())\n",
 | 
						||
    "print(f\"Total number of parameters: {total_params:,}\")"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 41,
 | 
						||
   "id": "u4J7IxOvOyPM",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/",
 | 
						||
     "height": 145,
 | 
						||
     "referenced_widgets": [
 | 
						||
      "5bbaa046d8934c8fae0a12c3d7bd991b",
 | 
						||
      "e1e4125eac004bae92dc1f22f673bf0e",
 | 
						||
      "d5b4bb4891ec4e44be46e9815c7e10dc",
 | 
						||
      "4f6595a392b244bd8e887935defc06f0",
 | 
						||
      "100c1b15cc4046cea1147f657eb2d8d0",
 | 
						||
      "81458e7953a349cfafccaa213b370406",
 | 
						||
      "a3dc9dfadae642b4a873705596739468",
 | 
						||
      "f55b59efcefa4ad5955d082f4bf7c637",
 | 
						||
      "1b02e0c7d1604b1c87a327c4c4f8b0e7",
 | 
						||
      "02ad170019454fd096b37347de5c481d",
 | 
						||
      "c52e0f34892b4daa84c1bf61500ac399",
 | 
						||
      "af985cf6fa26475eb2c4dd81e0c79ff4",
 | 
						||
      "8659c3eddb014c3bb5931fd9e6fadad8",
 | 
						||
      "f5fa00d96c4c49e48e1806d23a5b8570",
 | 
						||
      "080c484114f64f5591fa1287a35b46c9",
 | 
						||
      "14dc6a3717484c55a116612e28447dbb",
 | 
						||
      "00d3286c9c1d4161bb777b7b65ae744d",
 | 
						||
      "66f27fb11edf453b8144c2dfcdc66baa",
 | 
						||
      "5798e5118430439fb1f6bf29e1bafe58",
 | 
						||
      "357f367cf74146b8825be371acd51d06",
 | 
						||
      "94073be250cd42d5b82e196e30cbf22e",
 | 
						||
      "0cd0724f825e480389a82f0c49f91e6d",
 | 
						||
      "dffa208978f34e6a9aae94ecda92fe67",
 | 
						||
      "b8a98f163ebd4ac89af08a49c0881c23",
 | 
						||
      "f0d9febe1a634a0ba7e8e50fa104dcc2",
 | 
						||
      "e23870f0c7ff40cc8fa6a1e862a4af99",
 | 
						||
      "87da9905a0534c26ad0712ad426ca930",
 | 
						||
      "b953419300604b8e86fc0ad003fdfd2f",
 | 
						||
      "f1865ed0fbcc40eeabdca90a43d00069",
 | 
						||
      "ea0128909a9d4801ba312a876b0cf183",
 | 
						||
      "d160986df978416c9ad91d1e10fc90fc",
 | 
						||
      "5e97f7c2e8f5453dafcdad0552060e60",
 | 
						||
      "4b3e7b8774df4b458bb6c6146fe3226d",
 | 
						||
      "2ffd8dbed00e46d2887b9a2590cad297",
 | 
						||
      "a06dcb3bdfc84905a7222066c32fe500",
 | 
						||
      "e7602abc26714ee890a0cf5c0c7b67e1",
 | 
						||
      "dc5d555099f64a998514ebde90eeb6df",
 | 
						||
      "ef93a2f58cc54373941f43658bb808cf",
 | 
						||
      "fea1e2327d2944859af3d91c216b9008",
 | 
						||
      "320c00a5d18c45ccae634d166f1bd810",
 | 
						||
      "6c857e69d5204cd3b7c3bf426993ad1f",
 | 
						||
      "2145e47428f1446fba3e62b3cde0a7f5",
 | 
						||
      "3d519ce3562c4e249bf392c7f43d04c0",
 | 
						||
      "cc20ffcf0c1a4656945959bf457dfd84"
 | 
						||
     ]
 | 
						||
    },
 | 
						||
    "id": "u4J7IxOvOyPM",
 | 
						||
    "outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "eabfde3ef38b436ea750e6fb50a02b5c",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    },
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "e117ad45771747ae95c16f9876e6dc19",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    },
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "170185f2f046437dab57c2ad23163c5c",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    },
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "6e65f5d6c5af4ab78bc7b3778b98ef86",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "combined_weights = {}\n",
 | 
						||
    "\n",
 | 
						||
    "for i in range(1, 5):\n",
 | 
						||
    "    weights_file = hf_hub_download(\n",
 | 
						||
    "        repo_id=\"meta-llama/Llama-3.1-8B\",\n",
 | 
						||
    "        filename=f\"model-0000{i}-of-00004.safetensors\",\n",
 | 
						||
    "        local_dir=\"Llama-3.1-8B\"\n",
 | 
						||
    "    )\n",
 | 
						||
    "    current_weights = load_file(weights_file)\n",
 | 
						||
    "    combined_weights.update(current_weights)\n",
 | 
						||
    "\n",
 | 
						||
    "load_weights_into_llama(model, LLAMA31_CONFIG_8B, combined_weights)\n",
 | 
						||
    "model.to(device);\n",
 | 
						||
    "del combined_weights  # free up memory"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 42,
 | 
						||
   "id": "wJFnF8ATPbtD",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "wJFnF8ATPbtD",
 | 
						||
    "outputId": "67d5cb66-3588-4fd4-ac75-39bfe3aa82d8"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Output text:\n",
 | 
						||
      " Every effort has been made to trace copyright holders and to obtain their permission for the use of copyright material. The publisher apologizes for any\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "torch.manual_seed(123)\n",
 | 
						||
    "\n",
 | 
						||
    "token_ids = generate(\n",
 | 
						||
    "    model=model,\n",
 | 
						||
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
 | 
						||
    "    max_new_tokens=25,\n",
 | 
						||
    "    context_size=LLAMA31_CONFIG_8B[\"context_length\"],\n",
 | 
						||
    "    top_k=1,\n",
 | 
						||
    "    temperature=0.\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "DR9NBDUjPrDp",
 | 
						||
   "metadata": {
 | 
						||
    "id": "DR9NBDUjPrDp"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "# Llama 3.2 1B"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "imoxFiDzJcxk",
 | 
						||
   "metadata": {
 | 
						||
    "id": "imoxFiDzJcxk"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- As of this writing, Meta AI's latest models are the Llama 3.2 models announced [here](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/)\n",
 | 
						||
    "- The code for the Llama 3.2 text model is similar to that of Llama 3.1, except that the model has shrunk in size (there is a 1B and 3B version)\n",
 | 
						||
    "- The other efficiency tweak was that they added back weight tying (a concept that was original used in the GPT-2 architecture); here, they reuse the same weight parameter values in the input (token) embedding layer and output layer\n",
 | 
						||
    "- The small model size of Llama 3.2 1B is quite convenient, since it can even run on many mobile devices\n",
 | 
						||
    "- The architectural differences between Llama 3.1 8B and Llama 3.2 1B are illustrated in the figure below"
 | 
						||
   ]
 | 
						||
  },
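  {
   "cell_type": "markdown",
   "id": "b7d2e9f0-1a2b-4c3d-8e4f-5a6b7c8d9e0f",
   "metadata": {},
   "source": [
    "- To make the weight-tying idea concrete, below is a minimal sketch (not necessarily how the `Llama3Model` in this notebook wires it up internally): assigning the embedding weight to the output head makes both layers share a single parameter tensor, which works because `nn.Embedding(vocab_size, emb_dim)` and `nn.Linear(emb_dim, vocab_size, bias=False)` store weights of the same shape `(vocab_size, emb_dim)`:\n",
    "\n",
    "```python\n",
    "import torch.nn as nn\n",
    "\n",
    "class TiedLMSketch(nn.Module):\n",
    "    def __init__(self, vocab_size, emb_dim):\n",
    "        super().__init__()\n",
    "        self.tok_emb = nn.Embedding(vocab_size, emb_dim)\n",
    "        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)\n",
    "        # Weight tying: both layers now refer to the same parameter tensor\n",
    "        self.out_head.weight = self.tok_emb.weight\n",
    "```"
   ]
  },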
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "OL1EoXQ6TPb7",
 | 
						||
   "metadata": {
 | 
						||
    "id": "OL1EoXQ6TPb7"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/llama31-to-llama32.webp?1\" width=\"700px\">"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "K0KgjwCCJ9Fb",
 | 
						||
   "metadata": {
 | 
						||
    "id": "K0KgjwCCJ9Fb"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- As we can see based on the figure above, the main difference between the Llama 3.1 8B and Llama 3.2 1B architectures are the respective sizes\n",
 | 
						||
    "- A small additional change is an increased RoPE rescaling factor, which is reflected in the configuration file below"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 43,
 | 
						||
   "id": "Yv_yF3NCQTBx",
 | 
						||
   "metadata": {
 | 
						||
    "id": "Yv_yF3NCQTBx"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "LLAMA31_CONFIG_8B = {\n",
 | 
						||
    "    \"vocab_size\": 128_256,      # Vocabulary size\n",
 | 
						||
    "    \"context_length\": 131_072,  # NEW: Larger supported context length\n",
 | 
						||
    "    \"emb_dim\": 4096,            # Embedding dimension\n",
 | 
						||
    "    \"n_heads\": 32,              # Number of attention heads\n",
 | 
						||
    "    \"n_layers\": 32,             # Number of layers\n",
 | 
						||
    "    \"hidden_dim\": 14_336,       # Size of the intermediate dimension in FeedForward\n",
 | 
						||
    "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
 | 
						||
    "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
 | 
						||
    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usagey\n",
 | 
						||
    "    \"rope_freq\": {              # NEW: RoPE frequency scaling\n",
 | 
						||
    "        \"factor\": 8.0,\n",
 | 
						||
    "        \"low_freq_factor\": 1.0,\n",
 | 
						||
    "        \"high_freq_factor\": 4.0,\n",
 | 
						||
    "        \"original_context_length\": 8192,\n",
 | 
						||
    "    }\n",
 | 
						||
    "}\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "LLAMA32_CONFIG_1B = {\n",
 | 
						||
    "    \"vocab_size\": 128_256,      # Vocabulary size\n",
 | 
						||
    "    \"context_length\": 131_072,  # Context length\n",
 | 
						||
    "    \"emb_dim\": 2048,            # NEW: Half the embedding dimension\n",
 | 
						||
    "    \"n_heads\": 32,              # Number of attention heads\n",
 | 
						||
    "    \"n_layers\": 16,             # NEW: Half the number of layers\n",
 | 
						||
    "    \"hidden_dim\": 8192,         # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
 | 
						||
    "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
 | 
						||
    "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
 | 
						||
    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
 | 
						||
    "    \"rope_freq\": {              # RoPE frequency scaling\n",
 | 
						||
    "        \"factor\": 32.0,         # NEW: Adjustment of the rescaling factor\n",
 | 
						||
    "        \"low_freq_factor\": 1.0,\n",
 | 
						||
    "        \"high_freq_factor\": 4.0,\n",
 | 
						||
    "        \"original_context_length\": 8192,\n",
 | 
						||
    "    }\n",
 | 
						||
    "}"
 | 
						||
   ]
 | 
						||
  },
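  {
   "cell_type": "markdown",
   "id": "c4e8f1a2-3b4c-4d5e-9f6a-7b8c9d0e1f2a",
   "metadata": {},
   "source": [
    "- Optionally, a short sketch like the following makes the `# NEW` annotations explicit by printing every setting that changed between the two configuration dictionaries defined above:\n",
    "\n",
    "```python\n",
    "# Print each config entry that differs between Llama 3.1 8B and Llama 3.2 1B\n",
    "for key in LLAMA31_CONFIG_8B:\n",
    "    if LLAMA31_CONFIG_8B[key] != LLAMA32_CONFIG_1B[key]:\n",
    "        print(f\"{key}: {LLAMA31_CONFIG_8B[key]} -> {LLAMA32_CONFIG_1B[key]}\")\n",
    "```"
   ]
  },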
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
 | 
						||
   "metadata": {},
 | 
						||
   "source": [
 | 
						||
    "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 10,
 | 
						||
   "id": "73f001a6-7ae0-4204-aa83-a27a8878dfd2",
 | 
						||
   "metadata": {
 | 
						||
    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "New RoPE theta: 31250.0\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "old_context_length = LLAMA32_CONFIG_1B[\"context_length\"]\n",
 | 
						||
    "LLAMA32_CONFIG_1B[\"context_length\"] = 8192\n",
 | 
						||
    "\n",
 | 
						||
    "LLAMA32_CONFIG_1B[\"rope_base\"] = rescale_theta(\n",
 | 
						||
    "    LLAMA32_CONFIG_1B[\"rope_base\"],\n",
 | 
						||
    "    old_context_length,\n",
 | 
						||
    "    LLAMA32_CONFIG_1B[\"context_length\"]\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"New RoPE theta:\", LLAMA32_CONFIG_1B[\"rope_base\"])"
 | 
						||
   ]
 | 
						||
  },
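  {
   "cell_type": "markdown",
   "id": "d9a0b1c2-4d5e-4f6a-8b7c-9d0e1f2a3b4c",
   "metadata": {},
   "source": [
    "- As a sanity check, the printed value matches a simple linear rescaling of the RoPE base by the same factor (16×) that the context length was reduced by: `500_000.0 * 8192 / 131_072 = 31_250.0`"
   ]
  },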
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "Dl4_0EoJKKYv",
 | 
						||
   "metadata": {
 | 
						||
    "id": "Dl4_0EoJKKYv"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- Below, we can reuse the code from the Llama 3.1 8B section to load the Llama 3.2 1B model\n",
 | 
						||
    "- Again, since the Llama 3.2 family is distinct from the Llama 3.1 family, you'd have to go to the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository and acknowledge the license terms for your Hugging Face access token to work for the download\n",
 | 
						||
    "- Tip: For simplicity, we only load the base model below, but there's also an instruction-finetuned version you can use by replacing `\"meta-llama/Llama-3.2-1B\"` with `\"meta-llama/Llama-3.2-1B-Instruct\"`"
 | 
						||
   ]
 | 
						||
  },
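  {
   "cell_type": "markdown",
   "id": "e5f6a7b8-9c0d-4e1f-a2b3-c4d5e6f7a8b9",
   "metadata": {},
   "source": [
    "- A hypothetical toggle for this swap (not used in the cells below, which keep the hardcoded repository name) could look as follows; the download calls would then use `repo_id` instead of the string literal:\n",
    "\n",
    "```python\n",
    "# Hypothetical convenience toggle; not part of the original workflow\n",
    "USE_INSTRUCT_MODEL = False\n",
    "\n",
    "repo_id = (\n",
    "    \"meta-llama/Llama-3.2-1B-Instruct\" if USE_INSTRUCT_MODEL\n",
    "    else \"meta-llama/Llama-3.2-1B\"\n",
    ")\n",
    "```"
   ]
  },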
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 44,
 | 
						||
   "id": "tCstHgyRRD2x",
 | 
						||
   "metadata": {
 | 
						||
    "id": "tCstHgyRRD2x"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "# free up memory\n",
 | 
						||
    "del model\n",
 | 
						||
    "\n",
 | 
						||
    "\n",
 | 
						||
    "gc.collect()  # Run Python garbage collector\n",
 | 
						||
    "\n",
 | 
						||
    "if torch.cuda.is_available():\n",
 | 
						||
    "    torch.cuda.empty_cache()"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 45,
 | 
						||
   "id": "jt8BKAHXRCPI",
 | 
						||
   "metadata": {
 | 
						||
    "id": "jt8BKAHXRCPI"
 | 
						||
   },
 | 
						||
   "outputs": [],
 | 
						||
   "source": [
 | 
						||
    "tokenizer_file_path = hf_hub_download(\n",
 | 
						||
    "    repo_id=\"meta-llama/Llama-3.2-1B\",\n",
 | 
						||
    "    filename=\"original/tokenizer.model\",\n",
 | 
						||
    "    local_dir=\"Llama-3.2-1B\"\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "tokenizer = Tokenizer(tokenizer_file_path)"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 46,
 | 
						||
   "id": "uf8KjasmRFSt",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "uf8KjasmRFSt",
 | 
						||
    "outputId": "4e718852-2aa1-4b5a-bec3-3d5f866a4038"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Total number of parameters: 1,498,482,688\n",
 | 
						||
      "\n",
 | 
						||
      "Total number of unique parameters: 1,235,814,400\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "model = Llama3Model(LLAMA32_CONFIG_1B)\n",
 | 
						||
    "\n",
 | 
						||
    "total_params = sum(p.numel() for p in model.parameters())\n",
 | 
						||
    "print(f\"Total number of parameters: {total_params:,}\")\n",
 | 
						||
    "\n",
 | 
						||
    "# Account for weight tying\n",
 | 
						||
    "total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
 | 
						||
    "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
 | 
						||
   ]
 | 
						||
  },
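  {
   "cell_type": "markdown",
   "id": "f0a1b2c3-5d6e-4f7a-b8c9-d0e1f2a3b4c5",
   "metadata": {},
   "source": [
    "- The gap between the two counts above is exactly the token-embedding matrix, whose values are shared with the output layer under weight tying: `128_256 * 2_048 = 262_668_288` parameters, and `1_498_482_688 - 262_668_288 = 1_235_814_400`"
   ]
  },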
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 47,
 | 
						||
   "id": "9FbCIYW7RIOe",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "9FbCIYW7RIOe",
 | 
						||
    "outputId": "35588405-e2e1-4871-a1db-1d4bcb852e49"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "data": {
 | 
						||
      "application/vnd.jupyter.widget-view+json": {
 | 
						||
       "model_id": "c309c56a6cdf426e8ba7967b6a21864e",
 | 
						||
       "version_major": 2,
 | 
						||
       "version_minor": 0
 | 
						||
      },
 | 
						||
      "text/plain": [
 | 
						||
       "model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]"
 | 
						||
      ]
 | 
						||
     },
 | 
						||
     "metadata": {},
 | 
						||
     "output_type": "display_data"
 | 
						||
    },
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Model uses weight tying.\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "weights_file = hf_hub_download(\n",
 | 
						||
    "    repo_id=\"meta-llama/Llama-3.2-1B\",\n",
 | 
						||
    "    filename=\"model.safetensors\",\n",
 | 
						||
    "    local_dir=\"Llama-3.2-1B\"\n",
 | 
						||
    ")\n",
 | 
						||
    "current_weights = load_file(weights_file)\n",
 | 
						||
    "\n",
 | 
						||
    "load_weights_into_llama(model, LLAMA32_CONFIG_1B, current_weights)\n",
 | 
						||
    "model.to(device);\n",
 | 
						||
    "del current_weights  # free up memory"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 48,
 | 
						||
   "id": "pPp5yjir6FYJ",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "pPp5yjir6FYJ",
 | 
						||
    "outputId": "6c8e79d2-0769-43a7-93b3-f04c030e1aac"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Weight tying: True\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "code",
 | 
						||
   "execution_count": 49,
 | 
						||
   "id": "3kh7yrw2W4qr",
 | 
						||
   "metadata": {
 | 
						||
    "colab": {
 | 
						||
     "base_uri": "https://localhost:8080/"
 | 
						||
    },
 | 
						||
    "id": "3kh7yrw2W4qr",
 | 
						||
    "outputId": "b7e66a17-57ec-4b0e-c4ff-8d9a6b8e6ea5"
 | 
						||
   },
 | 
						||
   "outputs": [
 | 
						||
    {
 | 
						||
     "name": "stdout",
 | 
						||
     "output_type": "stream",
 | 
						||
     "text": [
 | 
						||
      "Output text:\n",
 | 
						||
      " Every effort is made to ensure that the information on this website is accurate. However, we cannot guarantee that the information is accurate, complete\n"
 | 
						||
     ]
 | 
						||
    }
 | 
						||
   ],
 | 
						||
   "source": [
 | 
						||
    "torch.manual_seed(123)\n",
 | 
						||
    "\n",
 | 
						||
    "token_ids = generate(\n",
 | 
						||
    "    model=model,\n",
 | 
						||
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
 | 
						||
    "    max_new_tokens=25,\n",
 | 
						||
    "    context_size=LLAMA32_CONFIG_1B[\"context_length\"],\n",
 | 
						||
    "    top_k=1,\n",
 | 
						||
    "    temperature=0.\n",
 | 
						||
    ")\n",
 | 
						||
    "\n",
 | 
						||
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "VO4Qf0zyW1ZC",
 | 
						||
   "metadata": {
 | 
						||
    "id": "VO4Qf0zyW1ZC"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    " \n",
 | 
						||
    "# What's next?"
 | 
						||
   ]
 | 
						||
  },
 | 
						||
  {
 | 
						||
   "cell_type": "markdown",
 | 
						||
   "id": "CjCewpo2XPAd",
 | 
						||
   "metadata": {
 | 
						||
    "id": "CjCewpo2XPAd"
 | 
						||
   },
 | 
						||
   "source": [
 | 
						||
    "- This notebook concludes the conversion from GPT to Llama 3.2\n",
 | 
						||
    "- If you are interested in a more compact, standalone notebook, which only contains the Llama 3.2 code, check out the [standalone-llama32.ipynb](standalone-llama32.ipynb) notebook"
 | 
						||
   ]
 | 
						||
  }
 | 
						||
 ],
 | 
						||
 "metadata": {
 | 
						||
  "accelerator": "GPU",
 | 
						||
  "colab": {
 | 
						||
   "gpuType": "A100",
 | 
						||
   "provenance": []
 | 
						||
  },
 | 
						||
  "kernelspec": {
 | 
						||
   "display_name": "Python 3 (ipykernel)",
 | 
						||
   "language": "python",
 | 
						||
   "name": "python3"
 | 
						||
  },
 | 
						||
  "language_info": {
 | 
						||
   "codemirror_mode": {
 | 
						||
    "name": "ipython",
 | 
						||
    "version": 3
 | 
						||
   },
 | 
						||
   "file_extension": ".py",
 | 
						||
   "mimetype": "text/x-python",
 | 
						||
   "name": "python",
 | 
						||
   "nbconvert_exporter": "python",
 | 
						||
   "pygments_lexer": "ipython3",
 | 
						||
   "version": "3.10.16"
 | 
						||
  },
 | 
						||
  "widgets": {
 | 
						||
   "application/vnd.jupyter.widget-state+json": {
 | 
						||
    "00148825ce0248b7a23eb28e3eca6749": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "00d3286c9c1d4161bb777b7b65ae744d": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "02ad170019454fd096b37347de5c481d": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "05b502e1e3a9436297dafbb1ce7af722": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_25977b0d89084703ad787fe9208b5aad",
 | 
						||
       "IPY_MODEL_71a84ee5fc964ec89ff2832c84735cc2",
 | 
						||
       "IPY_MODEL_6aed783eccb942318e6384e253ad4924"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_84c34bfecda64391a609e19f131d51d4"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "080c484114f64f5591fa1287a35b46c9": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_94073be250cd42d5b82e196e30cbf22e",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_0cd0724f825e480389a82f0c49f91e6d",
 | 
						||
      "value": " 5.00G/5.00G [00:15<00:00, 326MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "0a08fb81165748748ccb080e6df0600f": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_9d49589118f5432cac49650251046429",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_f114549fe8ce49638a791ca2fecb2d89",
 | 
						||
      "value": "model-00003-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "0aa155b794a8426aa265f4a7670f43ad": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "0afc2d23514b45c9890b5d2ee4e6fa0b": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_e8b187b40ec14db3af17a380830a35bf",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_e94ca32eaa9f4714a3b05a5fdf24d02b",
 | 
						||
      "value": "model-00002-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "0cd0724f825e480389a82f0c49f91e6d": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "0d51fdc2c416474da04079db6579890f": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "100c1b15cc4046cea1147f657eb2d8d0": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "14dc6a3717484c55a116612e28447dbb": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "15ea8fcfe097471e8fc9502a162f5904": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "172c0c6955e1428b999dcb2d133704cd": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "1b02e0c7d1604b1c87a327c4c4f8b0e7": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "1bf7108774c34016a2193e2cd7639b7d": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "1cd5e07cad35450182004952de32c8e7": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "1da83719e47c4196b06f3aa32056b560": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "20ecac7c646b45938ed393cb20977c37": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "2145e47428f1446fba3e62b3cde0a7f5": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "25977b0d89084703ad787fe9208b5aad": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_20ecac7c646b45938ed393cb20977c37",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_ebe04aeaaac042aaaa0885992e45793d",
 | 
						||
      "value": "model-00004-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "279cffe683fe4e7383062162e07ed9ed": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "2a2ba3d065634484a932b8d3c212af56": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "2ffd8dbed00e46d2887b9a2590cad297": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_a06dcb3bdfc84905a7222066c32fe500",
 | 
						||
       "IPY_MODEL_e7602abc26714ee890a0cf5c0c7b67e1",
 | 
						||
       "IPY_MODEL_dc5d555099f64a998514ebde90eeb6df"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_ef93a2f58cc54373941f43658bb808cf"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "31d27bf34a74432f8e0dbfe9ecb76130": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_eb94612785e64552aea8674dc8647a93",
 | 
						||
      "max": 4915916176,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_279cffe683fe4e7383062162e07ed9ed",
 | 
						||
      "value": 4915916176
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "320c00a5d18c45ccae634d166f1bd810": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "357f367cf74146b8825be371acd51d06": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "3cea4b431147441a8d9bd872811d5974": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "3d519ce3562c4e249bf392c7f43d04c0": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "3da5d38bf3314d3eaa7cedebae41c076": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_3edd464991204b8690eae02f10b4cc00",
 | 
						||
      "max": 4999802720,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_ac1e34f4bd6c420bb6cc2fdde5f3ed4d",
 | 
						||
      "value": 4999802720
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "3edd464991204b8690eae02f10b4cc00": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "409470784b6346a981920350de4f6f28": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_9ba6a11ffd194bf9a0900f52a7ed4d4f",
 | 
						||
       "IPY_MODEL_acae8bbbb4a84ed49be72fecd11fb052",
 | 
						||
       "IPY_MODEL_e8a4b441281b4038bb0204d093411f68"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_bdf8b693821344fc97918e6cbc31c8bf"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "4b3e7b8774df4b458bb6c6146fe3226d": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "4f6595a392b244bd8e887935defc06f0": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_02ad170019454fd096b37347de5c481d",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_c52e0f34892b4daa84c1bf61500ac399",
 | 
						||
      "value": " 4.98G/4.98G [00:16<00:00, 316MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "519147a10b984befbd0f255f78c1f66a": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "55e6b727a4594078beb3853cc1891308": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_1cd5e07cad35450182004952de32c8e7",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_a63351a6715643378491ba831b3fb05d",
 | 
						||
      "value": " 5.00G/5.00G [00:16<00:00, 291MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "562e82438dbe41b793ff488b8447c5bf": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "5798e5118430439fb1f6bf29e1bafe58": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "58c9dec75a3346b1b787f88dd510d254": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "5a2886564d3f40ceaa30b743dbe81f45": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "5bbaa046d8934c8fae0a12c3d7bd991b": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_e1e4125eac004bae92dc1f22f673bf0e",
 | 
						||
       "IPY_MODEL_d5b4bb4891ec4e44be46e9815c7e10dc",
 | 
						||
       "IPY_MODEL_4f6595a392b244bd8e887935defc06f0"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_100c1b15cc4046cea1147f657eb2d8d0"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "5e97f7c2e8f5453dafcdad0552060e60": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "6022a9426683420690d9b41a0ca4f870": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_9492edc02dee456f840325d913fa4e4f",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_66dc94b23556499f985f8accbb1f89cb",
 | 
						||
      "value": "model-00001-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "603690f543114a7fb6aebd433c80bdc3": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_0aa155b794a8426aa265f4a7670f43ad",
 | 
						||
      "max": 4915916176,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_a06fbde549cc47fdaddfbdb82d35d823",
 | 
						||
      "value": 4915916176
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "616e383bb3d442bcb6edb2721a8180b6": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_8ba9f009e92a46fcbcbb401dc444f12e",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_d74186bb74d142dfb683fa347b6990f7",
 | 
						||
      "value": " 5.00G/5.00G [00:16<00:00, 305MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "6176990205cc499f8995c71fc6b9d4df": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "66c23ae98bcc45f18fc5c91e0e73c3e4": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "66dc94b23556499f985f8accbb1f89cb": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "66f27fb11edf453b8144c2dfcdc66baa": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "6aed783eccb942318e6384e253ad4924": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_7015bf6f85954036aaf8cc4f1c44ea0f",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_2a2ba3d065634484a932b8d3c212af56",
 | 
						||
      "value": " 1.17G/1.17G [00:04<00:00, 297MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "6c857e69d5204cd3b7c3bf426993ad1f": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "7015bf6f85954036aaf8cc4f1c44ea0f": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "71a84ee5fc964ec89ff2832c84735cc2": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_ca81071ab07446df96795a482ce0c630",
 | 
						||
      "max": 1168138808,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_e0550cab24c7492787af40dc4b8576bf",
 | 
						||
      "value": 1168138808
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "7266a729edfb4a44b5b1c67dc79be146": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "76dbab4873f342019c5d7624ae2c9775": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "773b802daed942f5a11f3eab3b83be08": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_172c0c6955e1428b999dcb2d133704cd",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_1bf7108774c34016a2193e2cd7639b7d",
 | 
						||
      "value": " 4.92G/4.92G [00:16<00:00, 297MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "77606cd2fe1b4d33a91ede944bb1dec0": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "7989003a613e45f780d3f800e121543a": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "7c6658cfff1a4d27af3de148184f77d9": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "80dfd3e80ceb444a83ec1fd65f9af80e": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "81458e7953a349cfafccaa213b370406": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "84c34bfecda64391a609e19f131d51d4": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "8659c3eddb014c3bb5931fd9e6fadad8": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_00d3286c9c1d4161bb777b7b65ae744d",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_66f27fb11edf453b8144c2dfcdc66baa",
 | 
						||
      "value": "model-00002-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "87da9905a0534c26ad0712ad426ca930": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "87f474861e54432e9d533e0a89bb77da": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "8ae98969541849efa356cf912ac39b1e": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_f9373112649945e3b446c3e1ec274dc1",
 | 
						||
       "IPY_MODEL_d49791082a304ade95c185c79fae1f41",
 | 
						||
       "IPY_MODEL_616e383bb3d442bcb6edb2721a8180b6"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_87f474861e54432e9d533e0a89bb77da"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "8ba9f009e92a46fcbcbb401dc444f12e": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "94073be250cd42d5b82e196e30cbf22e": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "9492edc02dee456f840325d913fa4e4f": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "97e8877869cd4be68ff38ce745be5045": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "98b4680141ee423bb5e43c47613d8440": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_b02ffefca3f34252914e76f4a8a467dc",
 | 
						||
       "IPY_MODEL_31d27bf34a74432f8e0dbfe9ecb76130",
 | 
						||
       "IPY_MODEL_a3137f3669b54e84be91010c9654d985"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_5a2886564d3f40ceaa30b743dbe81f45"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "9ba6a11ffd194bf9a0900f52a7ed4d4f": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_97e8877869cd4be68ff38ce745be5045",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_cc3da88e93c4499993b7bbb7d3064326",
 | 
						||
      "value": "model-00001-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "9bb60a5a3710463ebe3a17f8d2a446be": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_0a08fb81165748748ccb080e6df0600f",
 | 
						||
       "IPY_MODEL_603690f543114a7fb6aebd433c80bdc3",
 | 
						||
       "IPY_MODEL_773b802daed942f5a11f3eab3b83be08"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_7989003a613e45f780d3f800e121543a"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "9d49589118f5432cac49650251046429": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "a06dcb3bdfc84905a7222066c32fe500": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_fea1e2327d2944859af3d91c216b9008",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_320c00a5d18c45ccae634d166f1bd810",
 | 
						||
      "value": "model-00004-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "a06fbde549cc47fdaddfbdb82d35d823": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "a3137f3669b54e84be91010c9654d985": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_6176990205cc499f8995c71fc6b9d4df",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_66c23ae98bcc45f18fc5c91e0e73c3e4",
 | 
						||
      "value": " 4.92G/4.92G [00:16<00:00, 297MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "a3dc9dfadae642b4a873705596739468": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "a63351a6715643378491ba831b3fb05d": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "ac1e34f4bd6c420bb6cc2fdde5f3ed4d": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "acae8bbbb4a84ed49be72fecd11fb052": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_0d51fdc2c416474da04079db6579890f",
 | 
						||
      "max": 4976698672,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_c4598300a77b4667b1117f9499f5ccb7",
 | 
						||
      "value": 4976698672
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "af985cf6fa26475eb2c4dd81e0c79ff4": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_8659c3eddb014c3bb5931fd9e6fadad8",
 | 
						||
       "IPY_MODEL_f5fa00d96c4c49e48e1806d23a5b8570",
 | 
						||
       "IPY_MODEL_080c484114f64f5591fa1287a35b46c9"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_14dc6a3717484c55a116612e28447dbb"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "b02ffefca3f34252914e76f4a8a467dc": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_15ea8fcfe097471e8fc9502a162f5904",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_c779e80c50ba4434bfa1d326c5cc9b0f",
 | 
						||
      "value": "model-00003-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "b46a08cf4929422eb0f76d8d9af11249": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_1da83719e47c4196b06f3aa32056b560",
 | 
						||
      "max": 1168138808,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_c4a2c88326d14fbca87cfde073755a2e",
 | 
						||
      "value": 1168138808
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "b8a98f163ebd4ac89af08a49c0881c23": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_b953419300604b8e86fc0ad003fdfd2f",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_f1865ed0fbcc40eeabdca90a43d00069",
 | 
						||
      "value": "model-00003-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "b953419300604b8e86fc0ad003fdfd2f": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "bdf8b693821344fc97918e6cbc31c8bf": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "be3e9bf271f04eb0b119659e1af3a0ea": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "c4598300a77b4667b1117f9499f5ccb7": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "c4a2c88326d14fbca87cfde073755a2e": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "c52e0f34892b4daa84c1bf61500ac399": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "c779e80c50ba4434bfa1d326c5cc9b0f": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "ca81071ab07446df96795a482ce0c630": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "cc20ffcf0c1a4656945959bf457dfd84": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "cc3da88e93c4499993b7bbb7d3064326": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "d160986df978416c9ad91d1e10fc90fc": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "d49791082a304ade95c185c79fae1f41": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_00148825ce0248b7a23eb28e3eca6749",
 | 
						||
      "max": 4999802720,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_f1a9b0c2431640298a6c1b258298b12d",
 | 
						||
      "value": 4999802720
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "d598f094c3ce4daeab19fac8094cba7e": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_0afc2d23514b45c9890b5d2ee4e6fa0b",
 | 
						||
       "IPY_MODEL_3da5d38bf3314d3eaa7cedebae41c076",
 | 
						||
       "IPY_MODEL_55e6b727a4594078beb3853cc1891308"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_f17fa78263414ef8b414c7bf3ac03192"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "d5b4bb4891ec4e44be46e9815c7e10dc": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_f55b59efcefa4ad5955d082f4bf7c637",
 | 
						||
      "max": 4976698672,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_1b02e0c7d1604b1c87a327c4c4f8b0e7",
 | 
						||
      "value": 4976698672
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "d74186bb74d142dfb683fa347b6990f7": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "dc5d555099f64a998514ebde90eeb6df": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_3d519ce3562c4e249bf392c7f43d04c0",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_cc20ffcf0c1a4656945959bf457dfd84",
 | 
						||
      "value": " 1.17G/1.17G [00:03<00:00, 328MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "dffa208978f34e6a9aae94ecda92fe67": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_b8a98f163ebd4ac89af08a49c0881c23",
 | 
						||
       "IPY_MODEL_f0d9febe1a634a0ba7e8e50fa104dcc2",
 | 
						||
       "IPY_MODEL_e23870f0c7ff40cc8fa6a1e862a4af99"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_87da9905a0534c26ad0712ad426ca930"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e0550cab24c7492787af40dc4b8576bf": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e1e4125eac004bae92dc1f22f673bf0e": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_81458e7953a349cfafccaa213b370406",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_a3dc9dfadae642b4a873705596739468",
 | 
						||
      "value": "model-00001-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e23870f0c7ff40cc8fa6a1e862a4af99": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_5e97f7c2e8f5453dafcdad0552060e60",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_4b3e7b8774df4b458bb6c6146fe3226d",
 | 
						||
      "value": " 4.92G/4.92G [00:20<00:00, 317MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e7602abc26714ee890a0cf5c0c7b67e1": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_6c857e69d5204cd3b7c3bf426993ad1f",
 | 
						||
      "max": 1168138808,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_2145e47428f1446fba3e62b3cde0a7f5",
 | 
						||
      "value": 1168138808
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e805bb6dfee34dab8870f4618d8bffdb": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e8a4b441281b4038bb0204d093411f68": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_77606cd2fe1b4d33a91ede944bb1dec0",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_f1ba439c26d64c90af2f162c74348405",
 | 
						||
      "value": " 4.98G/4.98G [00:16<00:00, 296MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e8b187b40ec14db3af17a380830a35bf": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e94ca32eaa9f4714a3b05a5fdf24d02b": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "e9aba3d53b4d45c485a7aad649c7b465": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_7c6658cfff1a4d27af3de148184f77d9",
 | 
						||
      "max": 4976698672,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_7266a729edfb4a44b5b1c67dc79be146",
 | 
						||
      "value": 4976698672
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "ea0128909a9d4801ba312a876b0cf183": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "eb94612785e64552aea8674dc8647a93": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "ebe04aeaaac042aaaa0885992e45793d": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "ed28e180d94a4b7aa548581612e31232": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_ff4338faded5494da1ccb660e1c441ed",
 | 
						||
       "IPY_MODEL_b46a08cf4929422eb0f76d8d9af11249",
 | 
						||
       "IPY_MODEL_f049eb4a50f54c34912ca959d2eaf353"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_80dfd3e80ceb444a83ec1fd65f9af80e"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "ef93a2f58cc54373941f43658bb808cf": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f049eb4a50f54c34912ca959d2eaf353": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_f0ab5a46cbb0444c88ed137d8a95002b",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_f8f28ac0e149428f9fef42373c6a87d0",
 | 
						||
      "value": " 1.17G/1.17G [00:03<00:00, 307MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f0ab5a46cbb0444c88ed137d8a95002b": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f0d9febe1a634a0ba7e8e50fa104dcc2": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_ea0128909a9d4801ba312a876b0cf183",
 | 
						||
      "max": 4915916176,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_d160986df978416c9ad91d1e10fc90fc",
 | 
						||
      "value": 4915916176
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f114549fe8ce49638a791ca2fecb2d89": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f17fa78263414ef8b414c7bf3ac03192": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f1865ed0fbcc40eeabdca90a43d00069": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f1a12d7929db4309b9881853135359fc": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_76dbab4873f342019c5d7624ae2c9775",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_3cea4b431147441a8d9bd872811d5974",
 | 
						||
      "value": " 4.98G/4.98G [00:16<00:00, 309MB/s]"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f1a9b0c2431640298a6c1b258298b12d": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "ProgressStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "ProgressStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "bar_color": null,
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f1ba439c26d64c90af2f162c74348405": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f3788acce34f4956b0727b58d0cf38c6": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HBoxModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HBoxModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HBoxView",
 | 
						||
      "box_style": "",
 | 
						||
      "children": [
 | 
						||
       "IPY_MODEL_6022a9426683420690d9b41a0ca4f870",
 | 
						||
       "IPY_MODEL_e9aba3d53b4d45c485a7aad649c7b465",
 | 
						||
       "IPY_MODEL_f1a12d7929db4309b9881853135359fc"
 | 
						||
      ],
 | 
						||
      "layout": "IPY_MODEL_58c9dec75a3346b1b787f88dd510d254"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f55b59efcefa4ad5955d082f4bf7c637": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f5fa00d96c4c49e48e1806d23a5b8570": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "FloatProgressModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "FloatProgressModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "ProgressView",
 | 
						||
      "bar_style": "success",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_5798e5118430439fb1f6bf29e1bafe58",
 | 
						||
      "max": 4999802720,
 | 
						||
      "min": 0,
 | 
						||
      "orientation": "horizontal",
 | 
						||
      "style": "IPY_MODEL_357f367cf74146b8825be371acd51d06",
 | 
						||
      "value": 4999802720
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f8f28ac0e149428f9fef42373c6a87d0": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "DescriptionStyleModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "DescriptionStyleModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "StyleView",
 | 
						||
      "description_width": ""
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "f9373112649945e3b446c3e1ec274dc1": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_e805bb6dfee34dab8870f4618d8bffdb",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_be3e9bf271f04eb0b119659e1af3a0ea",
 | 
						||
      "value": "model-00002-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "fea1e2327d2944859af3d91c216b9008": {
 | 
						||
     "model_module": "@jupyter-widgets/base",
 | 
						||
     "model_module_version": "1.2.0",
 | 
						||
     "model_name": "LayoutModel",
 | 
						||
     "state": {
 | 
						||
      "_model_module": "@jupyter-widgets/base",
 | 
						||
      "_model_module_version": "1.2.0",
 | 
						||
      "_model_name": "LayoutModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/base",
 | 
						||
      "_view_module_version": "1.2.0",
 | 
						||
      "_view_name": "LayoutView",
 | 
						||
      "align_content": null,
 | 
						||
      "align_items": null,
 | 
						||
      "align_self": null,
 | 
						||
      "border": null,
 | 
						||
      "bottom": null,
 | 
						||
      "display": null,
 | 
						||
      "flex": null,
 | 
						||
      "flex_flow": null,
 | 
						||
      "grid_area": null,
 | 
						||
      "grid_auto_columns": null,
 | 
						||
      "grid_auto_flow": null,
 | 
						||
      "grid_auto_rows": null,
 | 
						||
      "grid_column": null,
 | 
						||
      "grid_gap": null,
 | 
						||
      "grid_row": null,
 | 
						||
      "grid_template_areas": null,
 | 
						||
      "grid_template_columns": null,
 | 
						||
      "grid_template_rows": null,
 | 
						||
      "height": null,
 | 
						||
      "justify_content": null,
 | 
						||
      "justify_items": null,
 | 
						||
      "left": null,
 | 
						||
      "margin": null,
 | 
						||
      "max_height": null,
 | 
						||
      "max_width": null,
 | 
						||
      "min_height": null,
 | 
						||
      "min_width": null,
 | 
						||
      "object_fit": null,
 | 
						||
      "object_position": null,
 | 
						||
      "order": null,
 | 
						||
      "overflow": null,
 | 
						||
      "overflow_x": null,
 | 
						||
      "overflow_y": null,
 | 
						||
      "padding": null,
 | 
						||
      "right": null,
 | 
						||
      "top": null,
 | 
						||
      "visibility": null,
 | 
						||
      "width": null
 | 
						||
     }
 | 
						||
    },
 | 
						||
    "ff4338faded5494da1ccb660e1c441ed": {
 | 
						||
     "model_module": "@jupyter-widgets/controls",
 | 
						||
     "model_module_version": "1.5.0",
 | 
						||
     "model_name": "HTMLModel",
 | 
						||
     "state": {
 | 
						||
      "_dom_classes": [],
 | 
						||
      "_model_module": "@jupyter-widgets/controls",
 | 
						||
      "_model_module_version": "1.5.0",
 | 
						||
      "_model_name": "HTMLModel",
 | 
						||
      "_view_count": null,
 | 
						||
      "_view_module": "@jupyter-widgets/controls",
 | 
						||
      "_view_module_version": "1.5.0",
 | 
						||
      "_view_name": "HTMLView",
 | 
						||
      "description": "",
 | 
						||
      "description_tooltip": null,
 | 
						||
      "layout": "IPY_MODEL_519147a10b984befbd0f255f78c1f66a",
 | 
						||
      "placeholder": "",
 | 
						||
      "style": "IPY_MODEL_562e82438dbe41b793ff488b8447c5bf",
 | 
						||
      "value": "model-00004-of-00004.safetensors: 100%"
 | 
						||
     }
 | 
						||
    }
 | 
						||
   }
 | 
						||
  }
 | 
						||
 },
 | 
						||
 "nbformat": 4,
 | 
						||
 "nbformat_minor": 5
 | 
						||
}
 |