{ "cells": [ { "cell_type": "markdown", "id": "0_xya1nyDHfY", "metadata": { "id": "0_xya1nyDHfY" }, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", "
\n", "
\n", "\n", "
" ] }, { "cell_type": "markdown", "id": "l62zIRRSBy_R", "metadata": { "id": "l62zIRRSBy_R" }, "source": [ "# Converting Llama 2 to Llama 3.2 From Scratch" ] }, { "cell_type": "markdown", "id": "aFmxTQbwCUMl", "metadata": { "id": "aFmxTQbwCUMl" }, "source": [ "- This is a follow-up notebook to [Converting a From-Scratch GPT Architecture to Llama 2](./converting-gpt-to-llama2.ipynb), converting Meta AI's Llama 2 architecture step by step to Llama 3, Llama 3.1, and Llama 3.2\n", "- The explanations are purposefully kept minimal in this notebook to avoid bloating it and to keep the focus on the main code\n", "- For more information about the architectures, please see the Llama 2 and Llama 3 papers\n", " - [Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)](https://arxiv.org/abs/2307.09288)\n", " - [The Llama 3 Herd of Models (2024)](https://arxiv.org/abs/2407.21783)" ] }, { "cell_type": "markdown", "id": "ohhMKUWvGm9z", "metadata": { "id": "ohhMKUWvGm9z" }, "source": [ "" ] }, { "cell_type": "code", "execution_count": 1, "id": "ws0wsUzwLH2k", "metadata": { "id": "ws0wsUzwLH2k" }, "outputs": [], "source": [ "# pip install -r requirements-extra.txt" ] }, { "cell_type": "markdown", "id": "JBpQwU89ETA1", "metadata": { "id": "JBpQwU89ETA1" }, "source": [ "- Packages that are being used in this notebook:" ] }, { "cell_type": "code", "execution_count": 2, "id": "34a9a440-84c2-42cc-808b-38677cb6af8a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "34a9a440-84c2-42cc-808b-38677cb6af8a", "outputId": "e3d3d4b6-ee63-4e28-d794-e8b0bdd931fd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "blobfile version: 3.0.0\n", "huggingface_hub version: 0.24.7\n", "tiktoken version: 0.8.0\n", "torch version: 2.4.1+cu121\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "pkgs = [\n", " \"blobfile\", # to download pretrained weights\n", " \"huggingface_hub\", # to download pretrained weights\n", " \"tiktoken\", # 
to implement the tokenizer\n", " \"torch\", # to implement the model\n", "]\n", "for p in pkgs:\n", " print(f\"{p} version: {version(p)}\")" ] }, { "cell_type": "markdown", "id": "UJJneXpTEg4W", "metadata": { "id": "UJJneXpTEg4W" }, "source": [ " \n", "# 1. Convert the Llama model implementation step by step" ] }, { "cell_type": "markdown", "id": "v1zpfX2GHBKa", "metadata": { "id": "v1zpfX2GHBKa" }, "source": [ "- If you are new to implementing LLM architectures, I recommend starting with [chapter 4](../../ch04/01_main-chapter-code/ch04.ipynb), which walks you through the implementation of the original GPT architecture step by step\n", "- The [Converting a From-Scratch GPT Architecture to Llama 2](./converting-gpt-to-llama2.ipynb) notebook then implements the Llama-specific components, such as RMSNorm layers, SiLU and SwiGLU activations, RoPE (rotary position embeddings), and the SentencePiece tokenizer\n", "- This notebook takes the Llama 2 architecture and transforms it into the Llama 3 architecture by\n", " 1. modifying the rotary embeddings\n", " 2. implementing grouped-query attention\n", " 3. 
using a customized version of the GPT-4 tokenizer\n", "- Later, we load the original Llama 3 weights shared by Meta AI into the architecture" ] }, { "cell_type": "markdown", "id": "c14b9121-abe1-4a46-99b8-acdef71e5b41", "metadata": { "id": "c14b9121-abe1-4a46-99b8-acdef71e5b41" }, "source": [ " \n", "## 1.1 Reusing Llama 2 components" ] }, { "cell_type": "markdown", "id": "dgDhJGJ6xR4e", "metadata": { "id": "dgDhJGJ6xR4e" }, "source": [ "- Llama 2 is quite similar to Llama 3, as mentioned above and illustrated in the figure at the top of this notebook\n", "- This means that we can import several building blocks from the [Llama 2 notebook](./converting-gpt-to-llama2.ipynb) using the following code" ] }, { "cell_type": "code", "execution_count": 3, "id": "a5bc3948-231b-4f1f-8d41-24ad0b7643d0", "metadata": { "id": "a5bc3948-231b-4f1f-8d41-24ad0b7643d0" }, "outputs": [], "source": [ "import os\n", "import sys\n", "import io\n", "import nbformat\n", "import types\n", "\n", "def import_from_notebook():\n", " def import_definitions_from_notebook(fullname, names):\n", " current_dir = os.getcwd()\n", " path = os.path.join(current_dir, fullname + \".ipynb\")\n", " path = os.path.normpath(path)\n", "\n", " # Load the notebook\n", " if not os.path.exists(path):\n", " raise FileNotFoundError(f\"Notebook file not found at: {path}\")\n", "\n", " with io.open(path, \"r\", encoding=\"utf-8\") as f:\n", " nb = nbformat.read(f, as_version=4)\n", "\n", " # Create a module to store the imported functions and classes\n", " mod = types.ModuleType(fullname)\n", " sys.modules[fullname] = mod\n", "\n", " # Go through the notebook cells and only execute function or class definitions\n", " for cell in nb.cells:\n", " if cell.cell_type == \"code\":\n", " cell_code = cell.source\n", " for name in names:\n", " # Check for function or class definitions\n", " if f\"def {name}\" in cell_code or f\"class {name}\" in cell_code:\n", " exec(cell_code, mod.__dict__)\n", " return 
mod\n", "\n", " fullname = \"converting-gpt-to-llama2\"\n", " names = [\"precompute_rope_params\", \"compute_rope\", \"SiLU\", \"FeedForward\", \"RMSNorm\", \"MultiHeadAttention\"]\n", "\n", " return import_definitions_from_notebook(fullname, names)" ] }, { "cell_type": "code", "execution_count": 4, "id": "d546032d-fce4-47cf-8d0e-682b78b21c61", "metadata": { "id": "d546032d-fce4-47cf-8d0e-682b78b21c61" }, "outputs": [], "source": [ "imported_module = import_from_notebook()\n", "\n", "# We redefine precompute_rope_params below (it gets a new freq_config argument)\n", "# precompute_rope_params = getattr(imported_module, \"precompute_rope_params\", None)\n", "compute_rope = getattr(imported_module, \"compute_rope\", None)\n", "SiLU = getattr(imported_module, \"SiLU\", None)\n", "FeedForward = getattr(imported_module, \"FeedForward\", None)\n", "RMSNorm = getattr(imported_module, \"RMSNorm\", None)\n", "\n", "# MultiHeadAttention only for comparison purposes\n", "MultiHeadAttention = getattr(imported_module, \"MultiHeadAttention\", None)" ] }, { "cell_type": "markdown", "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f", "metadata": { "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f" }, "source": [ " \n", "## 1.2 Modified RoPE" ] }, { "cell_type": "markdown", "id": "m9_oDcHCx8VI", "metadata": { "id": "m9_oDcHCx8VI" }, "source": [ "- Llama 3 uses rotary position embeddings (RoPE) similar to Llama 2 (for a detailed explanation, please see the [RoPE paper](https://arxiv.org/abs/2104.09864))\n", "- There are some subtle differences in the RoPE settings, though\n", " - Llama 3 now supports up to 8,192 tokens, twice as many as Llama 2 (4,096)\n", " - The base value for the so-called RoPE $\\theta$ was increased from 10,000 (Llama 2) to 500,000 (Llama 3); the base appears in the following equation (adapted from the [RoPE paper](https://arxiv.org/abs/2104.09864))\n", "\n", "$$\\Theta = \\left\\{\\theta_i = \\text{base}^{-\\frac{2(i-1)}{d}}, i \\in \\left[1, 2, ..., d/2\\right]\\right\\}$$\n", "\n", "- These $\\theta$ 
values are a set of predefined parameters that determine the rotation angles in the rotary matrix, where $d$ is the dimensionality of the vectors being rotated (the per-head dimension, `head_dim`, in the code below)\n", "- Increasing the base from 10,000 to 500,000 makes the frequencies (and hence the rotation angles) decrease more quickly across the dimensions, so the higher dimensions rotate more slowly than before; in other words, their wavelengths are stretched, which is meant to help the model encode positions over longer contexts\n", "- In addition, we introduce a `freq_config` section in the code below that adjusts the frequencies; however, we won't need it for Llama 3 (only for Llama 3.1 and Llama 3.2), so we will revisit this `freq_config` later (it's set to `None` and ignored by default)" ] }, { "cell_type": "code", "execution_count": 5, "id": "6Upl109OOAcu", "metadata": { "id": "6Upl109OOAcu" }, "outputs": [], "source": [ "import torch\n", "\n", "def precompute_rope_params(head_dim, theta_base=10000, context_length=4096, freq_config=None):\n", " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", "\n", " # Compute the inverse frequencies\n", " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", "\n", " ################################ NEW ###############################################\n", " # Frequency adjustments\n", " if freq_config is not None:\n", " low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n", " high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n", "\n", " wavelen = 2 * torch.pi / inv_freq\n", "\n", " inv_freq_llama = torch.where(\n", " wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n", " )\n", "\n", " smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n", " freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n", " )\n", "\n", " smoothed_inv_freq = (\n", " (1 - smooth_factor) * (inv_freq / 
freq_config[\"factor\"]) + smooth_factor * inv_freq\n", " )\n", "\n", " is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n", " inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n", " inv_freq = inv_freq_llama\n", " ####################################################################################\n", "\n", "\n", " # Generate position indices\n", " positions = torch.arange(context_length)\n", "\n", " # Compute the angles\n", " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", "\n", " # Precompute sine and cosine\n", " cos = torch.cos(angles)\n", " sin = torch.sin(angles)\n", "\n", " return cos, sin" ] }, { "cell_type": "markdown", "id": "jJBvO0YMJBXR", "metadata": { "id": "jJBvO0YMJBXR" }, "source": [ "- To summarize, what's new so far for Llama 3 compared to Llama 2 are the context length and theta base parameter:" ] }, { "cell_type": "code", "execution_count": 6, "id": "56c37216-e022-4603-be16-f9d3eaeaf4a1", "metadata": { "id": "56c37216-e022-4603-be16-f9d3eaeaf4a1" }, "outputs": [], "source": [ "# RoPE settings for Llama 2 and Llama 3\n", "\n", "llama_2_context_len = 4096\n", "llama_3_context_len = 8192\n", "\n", "llama_2_theta_base = 10_000\n", "llama_3_theta_base = 500_000" ] }, { "cell_type": "markdown", "id": "_V8v6i7MJItU", "metadata": { "id": "_V8v6i7MJItU" }, "source": [ "- The usage remains the same as before in Llama 2:" ] }, { "cell_type": "code", "execution_count": 7, "id": "dae70c8a-eb18-40f9-a2e5-a6af2a57628b", "metadata": { "id": "dae70c8a-eb18-40f9-a2e5-a6af2a57628b" }, "outputs": [], "source": [ "# Settings\n", "batch_size = 2\n", "num_heads = 4\n", "head_dim = 16\n", "\n", "# Instantiate RoPE parameters\n", "cos, sin = precompute_rope_params(\n", " head_dim=head_dim,\n", " theta_base=llama_3_theta_base,\n", " 
context_length=llama_3_context_len\n", ")\n", "\n", "# Dummy query and key tensors with shape (batch_size, num_heads, num_tokens, head_dim)\n", "torch.manual_seed(123)\n", "queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n", "keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n", "\n", "# Apply rotary position embeddings\n", "queries_rot = compute_rope(queries, cos, sin)\n", "keys_rot = compute_rope(keys, cos, sin)" ] }, { "cell_type": "markdown", "id": "cd19b75c-cf25-47b8-a010-6733fc0e9a8a", "metadata": { "id": "cd19b75c-cf25-47b8-a010-6733fc0e9a8a" }, "source": [ " \n", "## 1.3 Grouped-query attention" ] }, { "cell_type": "markdown", "id": "111c7d3f-fded-49e8-a617-9fe67b81dddc", "metadata": { "id": "111c7d3f-fded-49e8-a617-9fe67b81dddc" }, "source": [ "- In this section, we replace multi-head attention (MHA) with an alternative mechanism called grouped-query attention (GQA)\n", "- In short, one can think of GQA as a more compute- and parameter-efficient version of MHA\n", "- In GQA, we reduce the number of key and value projections by sharing them among multiple attention heads\n", "- Each attention head still has its unique query, but these queries attend to the same group of keys and values\n", "- Below is an illustration of GQA with 2 key-value groups (kv-groups):\n", "\n", "\n" ] }, { "cell_type": "markdown", "id": "perAYa2R_KW2", "metadata": { "id": "perAYa2R_KW2" }, "source": [ "- The main idea behind GQA is to reduce the number of unique key-value heads, with several query heads sharing the same keys and values; this reduces the size of some of the matrix multiplications and the number of parameters compared to MHA without significantly reducing modeling performance\n", "- The GQA code is very similar to MHA (I highlighted the changes below via the \"NEW\" sections)\n", "- In short, the main change in GQA is that each key-value group needs to be repeated to match the number of query heads it is associated with, as implemented below" ] }, { "cell_type": "code", "execution_count": 8, "id": 
"9b12e674-ef08-4dd7-8843-615b65b39c91", "metadata": { "id": "9b12e674-ef08-4dd7-8843-615b65b39c91" }, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "class GroupedQueryAttention(nn.Module):\n", " def __init__(\n", " self, d_in, d_out, context_length, num_heads,\n", " num_kv_groups, # NEW\n", " rope_base=10_000, # NEW\n", " rope_config=None, # NEW\n", " dtype=None\n", " ):\n", " super().__init__()\n", " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", "\n", " self.d_out = d_out\n", " self.num_heads = num_heads\n", " self.head_dim = d_out // num_heads\n", "\n", " ############################# NEW #############################\n", " # self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " # self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", " self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", " self.num_kv_groups = num_kv_groups\n", " self.group_size = num_heads // num_kv_groups\n", " ################################################################\n", "\n", " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", "\n", " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", " cos, sin = precompute_rope_params(\n", " head_dim=self.head_dim,\n", " theta_base=rope_base, # NEW\n", " freq_config=rope_config, # NEW\n", " context_length=context_length # match the size of the causal mask\n", " )\n", " self.register_buffer(\"cos\", cos)\n", " self.register_buffer(\"sin\", sin)\n", "\n", " def forward(self, x):\n", " b, num_tokens, d_in = x.shape\n", "\n", " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", " keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", " values 
= self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", "\n", " # Reshape queries, keys, and values\n", " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n", "\n", " ##################### NEW #####################\n", " # keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n", " # values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n", " keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", " ################################################\n", "\n", " # Transpose keys, values, and queries\n", " keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", " values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", " queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", "\n", " # Apply RoPE\n", " keys = compute_rope(keys, self.cos, self.sin)\n", " queries = compute_rope(queries, self.cos, self.sin)\n", "\n", " ##################### NEW #####################\n", " # Expand keys and values to match the number of heads\n", " # Shape: (b, num_heads, num_tokens, head_dim)\n", "\n", " keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", " values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", " # For example, before repeat_interleave along dim=1 (kv groups):\n", " # [K1, K2]\n", " # After repeat_interleave (each kv group is repeated group_size times):\n", " # [K1, K1, K2, K2]\n", " # If we used regular repeat instead of repeat_interleave, we'd get:\n", " # [K1, K2, K1, K2]\n", " ################################################\n", "\n", " # Compute scaled dot-product attention (aka self-attention) with a causal mask\n", " # Shape: (b, num_heads, num_tokens, num_tokens)\n", " attn_scores = queries @ keys.transpose(2, 3) # Dot product 
for each head\n", "\n", " # Original mask truncated to the number of tokens and converted to boolean\n", " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", "\n", " # Use the mask to fill attention scores\n", " attn_scores.masked_fill_(mask_bool, -torch.inf)\n", "\n", " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", " assert keys.shape[-1] == self.head_dim\n", "\n", " # Shape: (b, num_tokens, num_heads, head_dim)\n", " context_vec = (attn_weights @ values).transpose(1, 2)\n", "\n", " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", " context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n", " context_vec = self.out_proj(context_vec) # optional projection\n", "\n", " return context_vec" ] }, { "cell_type": "markdown", "id": "roAXSwJs9hR8", "metadata": { "id": "roAXSwJs9hR8" }, "source": [ "- To illustrate the parameter savings, consider the following multi-head attention example from the GPT and Llama 2 code:" ] }, { "cell_type": "code", "execution_count": 9, "id": "b4b8f085-349e-4674-a3f0-78fde0664fac", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "b4b8f085-349e-4674-a3f0-78fde0664fac", "outputId": "9da09d72-43b1-45af-d46f-6928ea4af33a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "W_key: torch.Size([4096, 4096])\n", "W_value: torch.Size([4096, 4096])\n", "W_query: torch.Size([4096, 4096])\n" ] } ], "source": [ "# Settings\n", "batch_size = 1\n", "context_len = 3000\n", "max_context_len = 8192\n", "embed_dim = 4096\n", "num_heads = 32\n", "\n", "\n", "example_batch = torch.randn((batch_size, context_len, embed_dim))\n", "\n", "mha = MultiHeadAttention(\n", " d_in=embed_dim,\n", " d_out=embed_dim,\n", " context_length=max_context_len,\n", " num_heads=num_heads\n", ")\n", "\n", "mha(example_batch)\n", "\n", "print(\"W_key:\", mha.W_key.weight.shape)\n", "print(\"W_value:\", mha.W_value.weight.shape)\n", "print(\"W_query:\", 
mha.W_query.weight.shape)" ] }, { "cell_type": "markdown", "id": "IMQtFkcQ9sXC", "metadata": { "id": "IMQtFkcQ9sXC" }, "source": [ "- Now, if we use grouped-query attention instead, with 8 kv-groups (that's how many Llama 3 8B uses), we can see that the number of rows of the key and value matrices is reduced by a factor of 4 (because 32 attention heads divided by 8 kv-groups is 4)" ] }, { "cell_type": "code", "execution_count": 10, "id": "15e65d3c-7b42-4ed3-bfee-bb09578657bb", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "15e65d3c-7b42-4ed3-bfee-bb09578657bb", "outputId": "69709a78-2aaa-4597-8142-2f44eb59753f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "W_key: torch.Size([1024, 4096])\n", "W_value: torch.Size([1024, 4096])\n", "W_query: torch.Size([4096, 4096])\n" ] } ], "source": [ "gqa = GroupedQueryAttention(\n", " d_in=embed_dim,\n", " d_out=embed_dim,\n", " context_length=max_context_len,\n", " num_heads=num_heads,\n", " num_kv_groups=8,\n", " rope_base=llama_3_theta_base\n", ")\n", "\n", "gqa(example_batch)\n", "\n", "print(\"W_key:\", gqa.W_key.weight.shape)\n", "print(\"W_value:\", gqa.W_value.weight.shape)\n", "print(\"W_query:\", gqa.W_query.weight.shape)" ] }, { "cell_type": "markdown", "id": "1a5d4c88-c66a-483b-b4e2-419ff9fd60d5", "metadata": { "id": "1a5d4c88-c66a-483b-b4e2-419ff9fd60d5" }, "source": [ "- As a side note, to make `GroupedQueryAttention` equivalent to standard multi-head attention, you can set the number of key-value groups (`num_kv_groups`) equal to the number of heads (`num_heads`)\n", "- Lastly, let's compare the number of parameters below:" ] }, { "cell_type": "code", "execution_count": 11, "id": "58f713aa-ac00-4e2f-8247-94609aa01350", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "58f713aa-ac00-4e2f-8247-94609aa01350", "outputId": "486dfd9c-9f3a-4b9e-f9a2-35fb43b9a5fb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number 
of parameters:\n", "MHA: 67,108,864\n", "GQA: 41,943,040\n" ] } ], "source": [ "print(\"Total number of parameters:\")\n", "\n", "mha_total_params = sum(p.numel() for p in mha.parameters())\n", "print(f\"MHA: {mha_total_params:,}\")\n", "\n", "gqa_total_params = sum(p.numel() for p in gqa.parameters())\n", "print(f\"GQA: {gqa_total_params:,}\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "78b60dfd-6c0f-41f7-8f0c-8e57116f07f5", "metadata": { "id": "78b60dfd-6c0f-41f7-8f0c-8e57116f07f5" }, "outputs": [], "source": [ "# Free up memory:\n", "del mha\n", "del gqa" ] }, { "cell_type": "markdown", "id": "8fcd8802-2859-45a2-905a-f4fe96629dd9", "metadata": { "id": "8fcd8802-2859-45a2-905a-f4fe96629dd9" }, "source": [ " \n", "## 1.4 Update the TransformerBlock module" ] }, { "cell_type": "markdown", "id": "KABNccft_YnR", "metadata": { "id": "KABNccft_YnR" }, "source": [ "- Next, we update the `TransformerBlock`\n", "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings" ] }, { "cell_type": "code", "execution_count": 13, "id": "f9fa8eb4-7196-4dee-aec6-0dcbc70921c4", "metadata": { "id": "f9fa8eb4-7196-4dee-aec6-0dcbc70921c4" }, "outputs": [], "source": [ "class TransformerBlock(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", " self.att = GroupedQueryAttention( # MultiHeadAttention(\n", " d_in=cfg[\"emb_dim\"],\n", " d_out=cfg[\"emb_dim\"],\n", " context_length=cfg[\"context_length\"],\n", " num_heads=cfg[\"n_heads\"],\n", " num_kv_groups=cfg[\"n_kv_groups\"], # NEW\n", " rope_base=cfg[\"rope_base\"], # NEW\n", " rope_config=cfg[\"rope_freq\"], # NEW\n", " dtype=cfg[\"dtype\"]\n", " )\n", " self.ff = FeedForward(cfg)\n", " self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", "\n", " def forward(self, x):\n", " # Shortcut connection for attention block\n", " shortcut = x\n", " x = self.norm1(x)\n", " x = self.att(x.to(torch.bfloat16)) # 
Shape [batch_size, num_tokens, emb_size]\n", " x = x + shortcut # Add the original input back\n", "\n", " # Shortcut connection for feed-forward block\n", " shortcut = x\n", " x = self.norm2(x)\n", " x = self.ff(x.to(torch.bfloat16))\n", " x = x + shortcut # Add the original input back\n", "\n", " return x" ] }, { "cell_type": "markdown", "id": "fd921ab5-c48c-4c52-bf41-b847b3b822b9", "metadata": { "id": "fd921ab5-c48c-4c52-bf41-b847b3b822b9" }, "source": [ " \n", "## 1.5 Defining the model class" ] }, { "cell_type": "markdown", "id": "M_tLAq_r_llN", "metadata": { "id": "M_tLAq_r_llN" }, "source": [ "- When setting up the model class, we fortunately don't have to do much; we just update the name to `Llama3Model`" ] }, { "cell_type": "code", "execution_count": 14, "id": "475755d6-01f7-4e6e-ad9a-cec6f031ebf6", "metadata": { "id": "475755d6-01f7-4e6e-ad9a-cec6f031ebf6" }, "outputs": [], "source": [ "# class Llama2Model(nn.Module):\n", "class Llama3Model(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", "\n", " self.trf_blocks = nn.Sequential(\n", " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", "\n", " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", " def forward(self, in_idx):\n", " batch_size, seq_len = in_idx.shape\n", " tok_embeds = self.tok_emb(in_idx)\n", " x = tok_embeds\n", " x = self.trf_blocks(x)\n", " x = self.final_norm(x)\n", " logits = self.out_head(x.to(torch.bfloat16))\n", " return logits" ] }, { "cell_type": "markdown", "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60", "metadata": { "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60" }, "source": [ " \n", "## 2. 
Initialize model" ] }, { "cell_type": "markdown", "id": "HoGGRAGykQTE", "metadata": { "id": "HoGGRAGykQTE" }, "source": [ "- Now we can define a Llama 3 configuration (the Llama 2 configuration is shown for comparison)" ] }, { "cell_type": "code", "execution_count": 15, "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18", "metadata": { "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18" }, "outputs": [], "source": [ "LLAMA2_CONFIG_7B = {\n", " \"vocab_size\": 32_000, # Vocabulary size\n", " \"context_length\": 4096, # Context length\n", " \"emb_dim\": 4096, # Embedding dimension\n", " \"n_heads\": 32, # Number of attention heads\n", " \"n_layers\": 32, # Number of layers\n", " \"hidden_dim\": 11_008, # Size of the intermediate dimension in FeedForward\n", " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n", "}" ] }, { "cell_type": "code", "execution_count": 16, "id": "2ad90f82-15c7-4806-b509-e45b56f57db5", "metadata": { "id": "2ad90f82-15c7-4806-b509-e45b56f57db5" }, "outputs": [], "source": [ "LLAMA3_CONFIG_8B = {\n", " \"vocab_size\": 128_256, # NEW: Larger vocabulary size\n", " \"context_length\": 8192, # NEW: Larger context length\n", " \"emb_dim\": 4096, # Embedding dimension\n", " \"n_heads\": 32, # Number of attention heads\n", " \"n_layers\": 32, # Number of layers\n", " \"hidden_dim\": 14_336, # NEW: Larger size of the intermediate dimension in FeedForward\n", " \"n_kv_groups\": 8, # NEW: Key-Value groups for grouped-query attention\n", " \"rope_base\": 500_000, # NEW: The base in RoPE's \"theta\" was increased to 500_000\n", " \"rope_freq\": None, # NEW: Additional configuration for adjusting the RoPE frequencies\n", " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n", "}" ] }, { "cell_type": "markdown", "id": "FAP7fiBzkaBz", "metadata": { "id": "FAP7fiBzkaBz" }, "source": [ "- Using these settings, we can now initialize a Llama 3 8B model\n", "- Note that this requires ~34 GB of memory (for comparison, Llama 2 7B required ~26 GB of 
memory)" ] }, { "cell_type": "code", "execution_count": 17, "id": "7004d785-ac9a-4df5-8760-6807fc604686", "metadata": { "id": "7004d785-ac9a-4df5-8760-6807fc604686" }, "outputs": [], "source": [ "model = Llama3Model(LLAMA3_CONFIG_8B)" ] }, { "cell_type": "code", "execution_count": 18, "id": "6079f747-8f20-4c6b-8d38-7156f1101729", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6079f747-8f20-4c6b-8d38-7156f1101729", "outputId": "0a8cd23b-d9fa-4c2d-ca63-3fc79bc4de0d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of parameters: 8,030,261,248\n" ] } ], "source": [ "total_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Total number of parameters: {total_params:,}\")" ] }, { "cell_type": "markdown", "id": "Bx14NtzWk2wj", "metadata": { "id": "Bx14NtzWk2wj" }, "source": [ "- As shown above, the model contains 8 billion parameters\n", "- Additionally, we can calculate the memory requirements for this model using the code below:" ] }, { "cell_type": "code", "execution_count": 19, "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a", "outputId": "3425e9ce-d8c0-4b37-bded-a2c60b66a41a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "float32 (PyTorch default): 68.08 GB\n", "bfloat16: 34.04 GB\n" ] } ], "source": [ "def model_memory_size(model, input_dtype=torch.float32):\n", " total_params = 0\n", " total_grads = 0\n", " for param in model.parameters():\n", " # Calculate total number of elements per parameter\n", " param_size = param.numel()\n", " total_params += param_size\n", " # Check if gradients are stored for this parameter\n", " if param.requires_grad:\n", " total_grads += param_size\n", "\n", " # Calculate buffer size (non-parameters that require memory)\n", " total_buffers = sum(buf.numel() for buf in model.buffers())\n", "\n", " # Size in bytes = (Number of 
elements) * (Size of each element in bytes)\n", " # We assume parameters and gradients are stored in the same type as input dtype\n", " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n", " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n", "\n", " # Convert bytes to gigabytes\n", " total_memory_gb = total_memory_bytes / (1024**3)\n", "\n", " return total_memory_gb\n", "\n", "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n", "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")" ] }, { "cell_type": "markdown", "id": "zudd-5PulKFL", "metadata": { "id": "zudd-5PulKFL" }, "source": [ "- Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:" ] }, { "cell_type": "code", "execution_count": 20, "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d", "metadata": { "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "elif torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", "else:\n", " device = torch.device(\"cpu\")\n", "\n", "model.to(device);" ] }, { "cell_type": "markdown", "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34", "metadata": { "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34" }, "source": [ " \n", "## 3. 
Load tokenizer" ] }, { "cell_type": "markdown", "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005", "metadata": { "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005" }, "source": [ "- In this section, we are going to load the tokenizer for the model\n", "- Llama 2 used Google's [SentencePiece](https://github.com/google/sentencepiece) tokenizer instead of OpenAI's BPE tokenizer based on the [Tiktoken](https://github.com/openai/tiktoken) library\n", "- Llama 3, however, switched to a BPE tokenizer built on Tiktoken; specifically, it uses the GPT-4 tokenizer with an extended vocabulary\n", "- You can find the original Tiktoken adaptation by Meta AI [here](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py) in their official Llama 3 repository\n", "- Below, I rewrote the tokenizer code to make it more readable and minimal for this notebook (but the behavior should be similar)" ] }, { "cell_type": "code", "execution_count": 21, "id": "5f390cbf-8f92-46dc-afe3-d90b5affae10", "metadata": { "id": "5f390cbf-8f92-46dc-afe3-d90b5affae10" }, "outputs": [], "source": [ "import os\n", "from pathlib import Path\n", "\n", "import tiktoken\n", "from tiktoken.load import load_tiktoken_bpe\n", "\n", "\n", "class Tokenizer:\n", " def __init__(self, model_path):\n", " assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n", " mergeable_ranks = load_tiktoken_bpe(model_path)\n", " num_base_tokens = len(mergeable_ranks)\n", "\n", " self.special_tokens = {\n", " \"<|begin_of_text|>\": 128000,\n", " \"<|end_of_text|>\": 128001,\n", " \"<|start_header_id|>\": 128006,\n", " \"<|end_header_id|>\": 128007,\n", " \"<|eot_id|>\": 128009,\n", " }\n", " # Fill the remaining special-token IDs 128002-128255 (the vocabulary has 128,256 entries)\n", " self.special_tokens.update({\n", " f\"<|reserved_{i}|>\": 128002 + i for i in range(254) if (128002 + i) not in self.special_tokens.values()\n", " })\n", "\n", " self.model = tiktoken.Encoding(\n", " name=Path(model_path).name,\n", " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| 
?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", " mergeable_ranks=mergeable_ranks,\n", " special_tokens=self.special_tokens\n", " )\n", "\n", "\n", " def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", " if bos:\n", " tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n", " else:\n", " tokens = []\n", "\n", " tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n", "\n", " if eos:\n", " tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", " return tokens\n", "\n", " def decode(self, tokens):\n", " return self.model.decode(tokens)" ] }, { "cell_type": "markdown", "id": "0a1509f8-8778-4fec-ba32-14d95c646167", "metadata": { "id": "0a1509f8-8778-4fec-ba32-14d95c646167" }, "source": [ "- Meta AI shared the original Llama 3 model weights and tokenizer vocabulary on the Hugging Face Hub\n", "- We will first download the tokenizer vocabulary from the Hub and load it into the code above" ] }, { "cell_type": "markdown", "id": "KbnlzsbYmJU6", "metadata": { "id": "KbnlzsbYmJU6" }, "source": [ "- Please note that Meta AI requires that you accept the Llama 3 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) repository to accept the terms\n", "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n", "\n", "\n", "\n", "\n", "- Then, create and copy the access token so you can copy & paste it into the next code cell\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 22, "id": "3357a230-b678-4691-a238-257ee4e80185", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3357a230-b678-4691-a238-257ee4e80185", "outputId": "a3652def-ea7f-46fb-f293-2a59affb71a0" }, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", "Token is valid (permission: read).\n", "Your token has been saved to /root/.cache/huggingface/token\n", "Login successful\n" ] } ], "source": [ "from huggingface_hub import login\n", "import json\n", "\n", "with open(\"config.json\", \"r\") as config_file:\n", " config = json.load(config_file)\n", " access_token = config[\"HF_ACCESS_TOKEN\"]\n", "\n", "login(token=access_token)" ] }, { "cell_type": "markdown", "id": "IxGh6ZYQo0VN", "metadata": { "id": "IxGh6ZYQo0VN" }, "source": [ "- After logging in with the access token, which verifies that we have accepted the Llama 3 licensing terms, we can download the tokenizer vocabulary:" ] }, { "cell_type": "code", "execution_count": 23, "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32", "outputId": "c9836ba8-5176-4dd5-b618-6cc36fdbe1f0" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] } ], "source": [ "from huggingface_hub import hf_hub_download\n", "\n", "tokenizer_file_path = hf_hub_download(\n", " repo_id=\"meta-llama/Meta-Llama-3-8B\",\n", 
filename=\"original/tokenizer.model\",\n", " local_dir=\"llama3-files\"\n", ")" ] }, { "cell_type": "markdown", "id": "F8BH1Nk0AYCS", "metadata": { "id": "F8BH1Nk0AYCS" }, "source": [ "- Note that to load the Llama 3 tokenizer file, we may need the `blobfile` package; depending on the version, Tiktoken's `load_tiktoken_bpe` reads local (non-URL) file paths through `blobfile`, a library otherwise used for files stored in cloud storage solutions like Google Cloud Storage (GCS), Azure Blob Storage, or Amazon S3\n", "- You can install this dependency by uncommenting and executing the `pip` command below\n" ] }, { "cell_type": "code", "execution_count": 24, "id": "5dm6Oz7uAytV", "metadata": { "id": "5dm6Oz7uAytV" }, "outputs": [], "source": [ "# pip install blobfile" ] }, { "cell_type": "code", "execution_count": 25, "id": "8b8c0ce6-a6fb-4b8a-8de2-ee7bb7646fd0", "metadata": { "id": "8b8c0ce6-a6fb-4b8a-8de2-ee7bb7646fd0" }, "outputs": [], "source": [ "tokenizer = Tokenizer(tokenizer_file_path)" ] }, { "cell_type": "markdown", "id": "NVhmFeX3pT_M", "metadata": { "id": "NVhmFeX3pT_M" }, "source": [ "- We can now use the `generate` function to have the Llama 3 model generate new text:" ] }, { "cell_type": "code", "execution_count": 26, "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e", "outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output text:\n", " Every effort_dead aeros Ingredients başında.extension clangmissions.esp 사진 Ek Pars til DoctorsDaoеньostivan normal Ekized � Ekized � Ek rdr tık%,orgen>',\n", "\n" ] } ], "source": [ "from previous_chapters import generate, text_to_token_ids, token_ids_to_text\n", "\n", "\n", "torch.manual_seed(123)\n", "\n", "token_ids = generate(\n", " model=model,\n", " idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n", " max_new_tokens=30,\n", " context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n", " top_k=1,\n", " temperature=0.\n", ")\n", "\n", 
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" ] }, { "cell_type": "markdown", "id": "93WTtAA5paYV", "metadata": { "id": "93WTtAA5paYV" }, "source": [ "- As expected, the text above is nonsensical since we haven't trained the Llama 3 model yet\n", "- In the next section, instead of training it ourselves, which would cost millions of dollars in GPU compute, we load the pretrained weights from Meta AI" ] }, { "cell_type": "markdown", "id": "f63cc248-1d27-4eb6-aa50-173b436652f8", "metadata": { "id": "f63cc248-1d27-4eb6-aa50-173b436652f8" }, "source": [ " \n", "## 4. Load pretrained weights" ] }, { "cell_type": "markdown", "id": "aKeN7rUfqZMI", "metadata": { "id": "aKeN7rUfqZMI" }, "source": [ "- We are loading the [\"meta-llama/Meta-Llama-3-8B\"](https://huggingface.co/meta-llama/Meta-Llama-3-8B) base model below, which is a pretrained text-completion model that has not been instruction-finetuned\n", "- Alternatively, you can load the instruction-finetuned and aligned [\"meta-llama/Meta-Llama-3-8B-Instruct\"](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model by modifying the string in the next code cell accordingly\n", "- Combined, the weight files are about 16 GB" ] }, { "cell_type": "code", "execution_count": 27, "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145, "referenced_widgets": [ "f3788acce34f4956b0727b58d0cf38c6", "6022a9426683420690d9b41a0ca4f870", "e9aba3d53b4d45c485a7aad649c7b465", "f1a12d7929db4309b9881853135359fc", "58c9dec75a3346b1b787f88dd510d254", "9492edc02dee456f840325d913fa4e4f", "66dc94b23556499f985f8accbb1f89cb", "7c6658cfff1a4d27af3de148184f77d9", "7266a729edfb4a44b5b1c67dc79be146", "76dbab4873f342019c5d7624ae2c9775", "3cea4b431147441a8d9bd872811d5974", "8ae98969541849efa356cf912ac39b1e", "f9373112649945e3b446c3e1ec274dc1", "d49791082a304ade95c185c79fae1f41", "616e383bb3d442bcb6edb2721a8180b6", "87f474861e54432e9d533e0a89bb77da", 
"e805bb6dfee34dab8870f4618d8bffdb", "be3e9bf271f04eb0b119659e1af3a0ea", "00148825ce0248b7a23eb28e3eca6749", "f1a9b0c2431640298a6c1b258298b12d", "8ba9f009e92a46fcbcbb401dc444f12e", "d74186bb74d142dfb683fa347b6990f7", "9bb60a5a3710463ebe3a17f8d2a446be", "0a08fb81165748748ccb080e6df0600f", "603690f543114a7fb6aebd433c80bdc3", "773b802daed942f5a11f3eab3b83be08", "7989003a613e45f780d3f800e121543a", "9d49589118f5432cac49650251046429", "f114549fe8ce49638a791ca2fecb2d89", "0aa155b794a8426aa265f4a7670f43ad", "a06fbde549cc47fdaddfbdb82d35d823", "172c0c6955e1428b999dcb2d133704cd", "1bf7108774c34016a2193e2cd7639b7d", "ed28e180d94a4b7aa548581612e31232", "ff4338faded5494da1ccb660e1c441ed", "b46a08cf4929422eb0f76d8d9af11249", "f049eb4a50f54c34912ca959d2eaf353", "80dfd3e80ceb444a83ec1fd65f9af80e", "519147a10b984befbd0f255f78c1f66a", "562e82438dbe41b793ff488b8447c5bf", "1da83719e47c4196b06f3aa32056b560", "c4a2c88326d14fbca87cfde073755a2e", "f0ab5a46cbb0444c88ed137d8a95002b", "f8f28ac0e149428f9fef42373c6a87d0" ] }, "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4", "outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f3788acce34f4956b0727b58d0cf38c6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00\"])\n", " tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", " tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n", " tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n", " return tokens\n", "\n", " def encode(self, text):\n", " message = {\n", " \"role\": \"user\",\n", " \"content\": text\n", " }\n", "\n", " tokens = self.encode_header(message)\n", " tokens.extend(\n", " self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n", " )\n", " tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n", " return tokens\n", "\n", " def decode(self, 
token_ids):\n", " return self.tokenizer.decode(token_ids)\n", "\n", "\n", "chat_tokenizer = ChatFormat(tokenizer)" ] }, { "cell_type": "markdown", "id": "M-dkSNvwDttN", "metadata": { "id": "M-dkSNvwDttN" }, "source": [ "- The usage is as follows:" ] }, { "cell_type": "code", "execution_count": 34, "id": "nwBrTGTsUNhn", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nwBrTGTsUNhn", "outputId": "72a495b4-b872-429a-88ef-49a9b4577f0f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[128006, 882, 128007, 271, 9906, 4435, 0, 128009]\n" ] } ], "source": [ "token_ids = chat_tokenizer.encode(\"Hello World!\")\n", "print(token_ids)" ] }, { "cell_type": "code", "execution_count": 35, "id": "0fpmpVgYVTRZ", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "0fpmpVgYVTRZ", "outputId": "bb3e819a-112a-466c-ac51-5d14a9c3475b" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'<|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|>'" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode(token_ids)" ] }, { "cell_type": "markdown", "id": "Wo-aUGeKDvqq", "metadata": { "id": "Wo-aUGeKDvqq" }, "source": [ "- Let's now see the Llama 3 instruction model in action:" ] }, { "cell_type": "code", "execution_count": 36, "id": "ozGOBu6XOkEW", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ozGOBu6XOkEW", "outputId": "4f689c70-bed9-46f3-a52a-aea47b641283" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output text:\n", " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:\n", "\n", "1. Grass: Llamas love to graze on grass, especially in the spring and summer months.\n", "2. Hay: Hay is a staple in a llama's diet. 
They like to eat timothy hay, alfalfa hay, and other types of hay.\n", "3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10% of a llama's diet.\n", "4. Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as apples,\n" ] } ], "source": [ "torch.manual_seed(123)\n", "\n", "token_ids = generate(\n", " model=model,\n", " idx=text_to_token_ids(\"What do llamas eat?\", chat_tokenizer).to(device),\n", " max_new_tokens=150,\n", " context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n", " top_k=1,\n", " temperature=0.\n", ")\n", "\n", "output_text = token_ids_to_text(token_ids, tokenizer)\n", "\n", "\n", "def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n", " # Find the index of the first occurrence of the assistant header (header_end)\n", " index = text.find(header_end)\n", "\n", " if index != -1:\n", " # Return the text after the assistant header, without leading/trailing whitespace\n", " return text[index + len(header_end):].strip()\n", " else:\n", " # If the header is not found, return the original text\n", " return text\n", "\n", "print(\"Output text:\\n\", clean_text(output_text))" ] }, { "cell_type": "markdown", "id": "2r5JKrO-ZOHK", "metadata": { "id": "2r5JKrO-ZOHK" }, "source": [ " \n", "# Llama 3.1 8B" ] }, { "cell_type": "markdown", "id": "QiQxX0XnP_iC", "metadata": { "id": "QiQxX0XnP_iC" }, "source": [ "- A few months after the initial Llama 3 release, Meta AI followed up with their Llama 3.1 suite of models (see the official [Introducing Llama 3.1: Our most capable models to date](https://ai.meta.com/blog/meta-llama-3-1/) announcement blog post for details)\n", "- Conveniently, we can reuse our previous Llama 3 code from above to implement Llama 3.1 8B\n", "\n", "\n", "\n", "- The architecture is identical, with the only change being a rescaling of the RoPE frequencies as indicated in the configuration file 
below\n", "\n" ] }, { "cell_type": "code", "execution_count": 37, "id": "X5Fg8XUHMv4M", "metadata": { "id": "X5Fg8XUHMv4M" }, "outputs": [], "source": [ "LLAMA3_CONFIG_8B = {\n", " \"vocab_size\": 128_256, # Vocabulary size\n", " \"context_length\": 8192, # Context length\n", " \"emb_dim\": 4096, # Embedding dimension\n", " \"n_heads\": 32, # Number of attention heads\n", " \"n_layers\": 32, # Number of layers\n", " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", " \"rope_freq\": None, # Additional configuration for adjusting the RoPE frequencies\n", " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n", "}\n", "\n", "LLAMA31_CONFIG_8B = {\n", " \"vocab_size\": 128_256, # Vocabulary size\n", " \"context_length\": 8192, # Context length\n", " \"emb_dim\": 4096, # Embedding dimension\n", " \"n_heads\": 32, # Number of attention heads\n", " \"n_layers\": 32, # Number of layers\n", " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", " \"rope_freq\": { # NEW: RoPE frequency scaling\n", " \"factor\": 8.0,\n", " \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", " \"original_context_length\": 8192,\n", " }\n", "}" ] }, { "cell_type": "markdown", "id": "xa3bpMDtTdBs", "metadata": { "id": "xa3bpMDtTdBs" }, "source": [ "- As we've seen in the code earlier, the RoPE method uses sinusoidal functions (sine and cosine) to embed positional information directly into the attention mechanism\n", "- In Llama 3.1, via the additional configuration, we introduce additional adjustments to the inverse frequency calculations\n", "- These adjustments influence how 
different frequency components contribute to the positional embeddings (a detailed explanation is a topic for another time)\n", "- Let's try out the Llama 3.1 model in practice; first, we clear out the old model to free up some GPU memory" ] }, { "cell_type": "code", "execution_count": 38, "id": "7dUtYnNUOqhL", "metadata": { "id": "7dUtYnNUOqhL" }, "outputs": [], "source": [ "# free up memory\n", "del model\n", "\n", "gc.collect() # Run Python garbage collector\n", "\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "id": "DbbVsll6TYWR", "metadata": { "id": "DbbVsll6TYWR" }, "source": [ "- Next, we download the tokenizer\n", "- Note that since the Llama 3.1 family is distinct from the Llama 3 family, you'd have to go to the [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) repository and acknowledge the license terms for your Hugging Face access token to work for the download\n", "- Tip: For simplicity, we only load the base model below, but there's also an instruction-finetuned version you can use by replacing `\"meta-llama/Llama-3.1-8B\"` with `\"meta-llama/Llama-3.1-8B-Instruct\"`" ] }, { "cell_type": "code", "execution_count": 39, "id": "8xDk4chtPNU4", "metadata": { "id": "8xDk4chtPNU4" }, "outputs": [], "source": [ "tokenizer_file_path = hf_hub_download(\n", " repo_id=\"meta-llama/Llama-3.1-8B\",\n", " filename=\"original/tokenizer.model\",\n", " local_dir=\"llama3-files\"\n", ")\n", "\n", "tokenizer = Tokenizer(tokenizer_file_path)" ] }, { "cell_type": "code", "execution_count": 40, "id": "a7l21VE4Otcs", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "a7l21VE4Otcs", "outputId": "3dd5cfba-bf3f-44d2-9be1-7cd42bfe4ba9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of parameters: 8,030,261,248\n" ] } ], "source": [ "model = Llama3Model(LLAMA31_CONFIG_8B)\n", "\n", "total_params = sum(p.numel() for p in model.parameters())\n", 
"print(f\"Total number of parameters: {total_params:,}\")" ] }, { "cell_type": "code", "execution_count": 41, "id": "u4J7IxOvOyPM", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145, "referenced_widgets": [ "5bbaa046d8934c8fae0a12c3d7bd991b", "e1e4125eac004bae92dc1f22f673bf0e", "d5b4bb4891ec4e44be46e9815c7e10dc", "4f6595a392b244bd8e887935defc06f0", "100c1b15cc4046cea1147f657eb2d8d0", "81458e7953a349cfafccaa213b370406", "a3dc9dfadae642b4a873705596739468", "f55b59efcefa4ad5955d082f4bf7c637", "1b02e0c7d1604b1c87a327c4c4f8b0e7", "02ad170019454fd096b37347de5c481d", "c52e0f34892b4daa84c1bf61500ac399", "af985cf6fa26475eb2c4dd81e0c79ff4", "8659c3eddb014c3bb5931fd9e6fadad8", "f5fa00d96c4c49e48e1806d23a5b8570", "080c484114f64f5591fa1287a35b46c9", "14dc6a3717484c55a116612e28447dbb", "00d3286c9c1d4161bb777b7b65ae744d", "66f27fb11edf453b8144c2dfcdc66baa", "5798e5118430439fb1f6bf29e1bafe58", "357f367cf74146b8825be371acd51d06", "94073be250cd42d5b82e196e30cbf22e", "0cd0724f825e480389a82f0c49f91e6d", "dffa208978f34e6a9aae94ecda92fe67", "b8a98f163ebd4ac89af08a49c0881c23", "f0d9febe1a634a0ba7e8e50fa104dcc2", "e23870f0c7ff40cc8fa6a1e862a4af99", "87da9905a0534c26ad0712ad426ca930", "b953419300604b8e86fc0ad003fdfd2f", "f1865ed0fbcc40eeabdca90a43d00069", "ea0128909a9d4801ba312a876b0cf183", "d160986df978416c9ad91d1e10fc90fc", "5e97f7c2e8f5453dafcdad0552060e60", "4b3e7b8774df4b458bb6c6146fe3226d", "2ffd8dbed00e46d2887b9a2590cad297", "a06dcb3bdfc84905a7222066c32fe500", "e7602abc26714ee890a0cf5c0c7b67e1", "dc5d555099f64a998514ebde90eeb6df", "ef93a2f58cc54373941f43658bb808cf", "fea1e2327d2944859af3d91c216b9008", "320c00a5d18c45ccae634d166f1bd810", "6c857e69d5204cd3b7c3bf426993ad1f", "2145e47428f1446fba3e62b3cde0a7f5", "3d519ce3562c4e249bf392c7f43d04c0", "cc20ffcf0c1a4656945959bf457dfd84" ] }, "id": "u4J7IxOvOyPM", "outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"5bbaa046d8934c8fae0a12c3d7bd991b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00" ] }, { "cell_type": "markdown", "id": "K0KgjwCCJ9Fb", "metadata": { "id": "K0KgjwCCJ9Fb" }, "source": [ "- As we can see based on the figure above, the main difference between the Llama 3.1 8B and Llama 3.2 1B architectures is their size\n", "- A small additional change is an increased RoPE rescaling factor, which is reflected in the configuration file below" ] }, { "cell_type": "code", "execution_count": 43, "id": "Yv_yF3NCQTBx", "metadata": { "id": "Yv_yF3NCQTBx" }, "outputs": [], "source": [ "LLAMA31_CONFIG_8B = {\n", " \"vocab_size\": 128_256, # Vocabulary size\n", " \"context_length\": 8192, # Context length\n", " \"emb_dim\": 4096, # Embedding dimension\n", " \"n_heads\": 32, # Number of attention heads\n", " \"n_layers\": 32, # Number of layers\n", " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", " \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n", " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", " \"rope_freq\": { # RoPE frequency scaling\n", " \"factor\": 8.0,\n", " \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", " \"original_context_length\": 8192,\n", " }\n", "}\n", "\n", "\n", "LLAMA32_CONFIG_1B = {\n", " \"vocab_size\": 128_256, # Vocabulary size\n", " \"context_length\": 8192, # Context length\n", " \"emb_dim\": 2048, # NEW: Half the embedding dimension\n", " \"n_heads\": 32, # Number of attention heads\n", " \"n_layers\": 16, # NEW: Half the number of layers\n", " \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n", " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", " \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n", " \"dtype\": torch.bfloat16, # Lower-precision 
dtype to save memory\n", " \"rope_freq\": { # RoPE frequency scaling\n", " \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n", " \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", " \"original_context_length\": 8192,\n", " }\n", "}" ] }, { "cell_type": "markdown", "id": "Dl4_0EoJKKYv", "metadata": { "id": "Dl4_0EoJKKYv" }, "source": [ "- Below, we can reuse the code from the Llama 3.1 8B section to load the Llama 3.2 1B model\n", "- Again, since the Llama 3.2 family is distinct from the Llama 3.1 family, you'd have to go to the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository and acknowledge the license terms for your Hugging Face access token to work for the download\n", "- Tip: For simplicity, we only load the base model below, but there's also an instruction-finetuned version you can use by replacing `\"meta-llama/Llama-3.2-1B\"` with `\"meta-llama/Llama-3.2-1B-Instruct\"`" ] }, { "cell_type": "code", "execution_count": 44, "id": "tCstHgyRRD2x", "metadata": { "id": "tCstHgyRRD2x" }, "outputs": [], "source": [ "# free up memory\n", "del model\n", "\n", "\n", "gc.collect() # Run Python garbage collector\n", "\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 45, "id": "jt8BKAHXRCPI", "metadata": { "id": "jt8BKAHXRCPI" }, "outputs": [], "source": [ "tokenizer_file_path = hf_hub_download(\n", " repo_id=\"meta-llama/Llama-3.2-1B\",\n", " filename=\"original/tokenizer.model\",\n", " local_dir=\"llama32-files\"\n", ")\n", "\n", "tokenizer = Tokenizer(tokenizer_file_path)" ] }, { "cell_type": "code", "execution_count": 50, "id": "uf8KjasmRFSt", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uf8KjasmRFSt", "outputId": "4e718852-2aa1-4b5a-bec3-3d5f866a4038" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of parameters: 1,498,482,688\n", "\n", "Total number of unique parameters: 
1,235,814,400\n" ] } ], "source": [ "model = Llama3Model(LLAMA32_CONFIG_1B)\n", "\n", "total_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Total number of parameters: {total_params:,}\")\n", "\n", "# Account for weight tying\n", "total_params_normalized = total_params - model.tok_emb.weight.numel()\n", "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")" ] }, { "cell_type": "code", "execution_count": 47, "id": "9FbCIYW7RIOe", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9FbCIYW7RIOe", "outputId": "35588405-e2e1-4871-a1db-1d4bcb852e49" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model uses weight tying.\n" ] } ], "source": [ "weights_file = hf_hub_download(\n", " repo_id=\"meta-llama/Llama-3.2-1B\",\n", " filename=\"model.safetensors\",\n", " local_dir=\"llama32-files\"\n", ")\n", "current_weights = load_file(weights_file)\n", "\n", "load_weights_into_llama(model, LLAMA32_CONFIG_1B, current_weights)\n", "model.to(device);" ] }, { "cell_type": "code", "execution_count": 48, "id": "pPp5yjir6FYJ", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pPp5yjir6FYJ", "outputId": "6c8e79d2-0769-43a7-93b3-f04c030e1aac" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Weight tying: True\n" ] } ], "source": [ "print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))" ] }, { "cell_type": "code", "execution_count": 49, "id": "3kh7yrw2W4qr", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3kh7yrw2W4qr", "outputId": "b7e66a17-57ec-4b0e-c4ff-8d9a6b8e6ea5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output text:\n", " Every effort is made to ensure that the information on this website is accurate. 
However, we cannot guarantee that the information is accurate, complete\n" ] } ], "source": [ "torch.manual_seed(123)\n", "\n", "token_ids = generate(\n", " model=model,\n", " idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n", " max_new_tokens=25,\n", " context_size=LLAMA32_CONFIG_1B[\"context_length\"],\n", " top_k=1,\n", " temperature=0.\n", ")\n", "\n", "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" ] }, { "cell_type": "markdown", "id": "VO4Qf0zyW1ZC", "metadata": { "id": "VO4Qf0zyW1ZC" }, "source": [ " \n", "# What's next?" ] }, { "cell_type": "markdown", "id": "CjCewpo2XPAd", "metadata": { "id": "CjCewpo2XPAd" }, "source": [ "- This notebook concludes the conversion from GPT to Llama 3.2\n", "- If you are interested in a more compact, standalone notebook, which only contains the Llama 3.2 code, check out the [standalone-llama32.ipynb](standalone-llama32.ipynb) notebook" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "A100", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "00148825ce0248b7a23eb28e3eca6749": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, 
"grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "00d3286c9c1d4161bb777b7b65ae744d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "02ad170019454fd096b37347de5c481d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, 
"_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "05b502e1e3a9436297dafbb1ce7af722": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_25977b0d89084703ad787fe9208b5aad", "IPY_MODEL_71a84ee5fc964ec89ff2832c84735cc2", "IPY_MODEL_6aed783eccb942318e6384e253ad4924" ], "layout": "IPY_MODEL_84c34bfecda64391a609e19f131d51d4" } }, "080c484114f64f5591fa1287a35b46c9": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_94073be250cd42d5b82e196e30cbf22e", "placeholder": "​", "style": "IPY_MODEL_0cd0724f825e480389a82f0c49f91e6d", "value": " 5.00G/5.00G [00:15<00:00, 326MB/s]" } }, "0a08fb81165748748ccb080e6df0600f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9d49589118f5432cac49650251046429", "placeholder": "​", "style": "IPY_MODEL_f114549fe8ce49638a791ca2fecb2d89", "value": "model-00003-of-00004.safetensors: 100%" } }, "0aa155b794a8426aa265f4a7670f43ad": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0afc2d23514b45c9890b5d2ee4e6fa0b": { 
"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e8b187b40ec14db3af17a380830a35bf", "placeholder": "​", "style": "IPY_MODEL_e94ca32eaa9f4714a3b05a5fdf24d02b", "value": "model-00002-of-00004.safetensors: 100%" } }, "0cd0724f825e480389a82f0c49f91e6d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "0d51fdc2c416474da04079db6579890f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, 
"object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "100c1b15cc4046cea1147f657eb2d8d0": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "14dc6a3717484c55a116612e28447dbb": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, 
"grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "15ea8fcfe097471e8fc9502a162f5904": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "172c0c6955e1428b999dcb2d133704cd": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": 
null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1b02e0c7d1604b1c87a327c4c4f8b0e7": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "1bf7108774c34016a2193e2cd7639b7d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "1cd5e07cad35450182004952de32c8e7": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": 
"LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1da83719e47c4196b06f3aa32056b560": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "20ecac7c646b45938ed393cb20977c37": { "model_module": "@jupyter-widgets/base", 
"model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2145e47428f1446fba3e62b3cde0a7f5": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "25977b0d89084703ad787fe9208b5aad": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_20ecac7c646b45938ed393cb20977c37", "placeholder": "​", "style": "IPY_MODEL_ebe04aeaaac042aaaa0885992e45793d", "value": "model-00004-of-00004.safetensors: 100%" } }, "279cffe683fe4e7383062162e07ed9ed": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "2a2ba3d065634484a932b8d3c212af56": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2ffd8dbed00e46d2887b9a2590cad297": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a06dcb3bdfc84905a7222066c32fe500", "IPY_MODEL_e7602abc26714ee890a0cf5c0c7b67e1", "IPY_MODEL_dc5d555099f64a998514ebde90eeb6df" ], "layout": "IPY_MODEL_ef93a2f58cc54373941f43658bb808cf" } }, "31d27bf34a74432f8e0dbfe9ecb76130": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": 
"FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_eb94612785e64552aea8674dc8647a93", "max": 4915916176, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_279cffe683fe4e7383062162e07ed9ed", "value": 4915916176 } }, "320c00a5d18c45ccae634d166f1bd810": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "357f367cf74146b8825be371acd51d06": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "3cea4b431147441a8d9bd872811d5974": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "3d519ce3562c4e249bf392c7f43d04c0": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3da5d38bf3314d3eaa7cedebae41c076": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_3edd464991204b8690eae02f10b4cc00", "max": 4999802720, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_ac1e34f4bd6c420bb6cc2fdde5f3ed4d", "value": 4999802720 } }, "3edd464991204b8690eae02f10b4cc00": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": 
null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "409470784b6346a981920350de4f6f28": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_9ba6a11ffd194bf9a0900f52a7ed4d4f", "IPY_MODEL_acae8bbbb4a84ed49be72fecd11fb052", "IPY_MODEL_e8a4b441281b4038bb0204d093411f68" ], "layout": "IPY_MODEL_bdf8b693821344fc97918e6cbc31c8bf" } }, "4b3e7b8774df4b458bb6c6146fe3226d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "4f6595a392b244bd8e887935defc06f0": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_02ad170019454fd096b37347de5c481d", "placeholder": "​", "style": "IPY_MODEL_c52e0f34892b4daa84c1bf61500ac399", "value": " 4.98G/4.98G [00:16<00:00, 316MB/s]" } }, "519147a10b984befbd0f255f78c1f66a": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "55e6b727a4594078beb3853cc1891308": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", 
"description": "", "description_tooltip": null, "layout": "IPY_MODEL_1cd5e07cad35450182004952de32c8e7", "placeholder": "​", "style": "IPY_MODEL_a63351a6715643378491ba831b3fb05d", "value": " 5.00G/5.00G [00:16<00:00, 291MB/s]" } }, "562e82438dbe41b793ff488b8447c5bf": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5798e5118430439fb1f6bf29e1bafe58": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "58c9dec75a3346b1b787f88dd510d254": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": 
"@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5a2886564d3f40ceaa30b743dbe81f45": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, 
"overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5bbaa046d8934c8fae0a12c3d7bd991b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_e1e4125eac004bae92dc1f22f673bf0e", "IPY_MODEL_d5b4bb4891ec4e44be46e9815c7e10dc", "IPY_MODEL_4f6595a392b244bd8e887935defc06f0" ], "layout": "IPY_MODEL_100c1b15cc4046cea1147f657eb2d8d0" } }, "5e97f7c2e8f5453dafcdad0552060e60": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6022a9426683420690d9b41a0ca4f870": { "model_module": "@jupyter-widgets/controls", 
"model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9492edc02dee456f840325d913fa4e4f", "placeholder": "​", "style": "IPY_MODEL_66dc94b23556499f985f8accbb1f89cb", "value": "model-00001-of-00004.safetensors: 100%" } }, "603690f543114a7fb6aebd433c80bdc3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0aa155b794a8426aa265f4a7670f43ad", "max": 4915916176, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_a06fbde549cc47fdaddfbdb82d35d823", "value": 4915916176 } }, "616e383bb3d442bcb6edb2721a8180b6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_8ba9f009e92a46fcbcbb401dc444f12e", "placeholder": "​", "style": "IPY_MODEL_d74186bb74d142dfb683fa347b6990f7", "value": " 5.00G/5.00G [00:16<00:00, 305MB/s]" } }, "6176990205cc499f8995c71fc6b9d4df": { "model_module": "@jupyter-widgets/base", "model_module_version": 
"1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "66c23ae98bcc45f18fc5c91e0e73c3e4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "66dc94b23556499f985f8accbb1f89cb": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "66f27fb11edf453b8144c2dfcdc66baa": { "model_module": "@jupyter-widgets/controls", 
"model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "6aed783eccb942318e6384e253ad4924": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7015bf6f85954036aaf8cc4f1c44ea0f", "placeholder": "​", "style": "IPY_MODEL_2a2ba3d065634484a932b8d3c212af56", "value": " 1.17G/1.17G [00:04<00:00, 297MB/s]" } }, "6c857e69d5204cd3b7c3bf426993ad1f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, 
"overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7015bf6f85954036aaf8cc4f1c44ea0f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "71a84ee5fc964ec89ff2832c84735cc2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ca81071ab07446df96795a482ce0c630", "max": 1168138808, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_e0550cab24c7492787af40dc4b8576bf", "value": 1168138808 } }, 
"7266a729edfb4a44b5b1c67dc79be146": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "76dbab4873f342019c5d7624ae2c9775": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "773b802daed942f5a11f3eab3b83be08": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", 
"description": "", "description_tooltip": null, "layout": "IPY_MODEL_172c0c6955e1428b999dcb2d133704cd", "placeholder": "​", "style": "IPY_MODEL_1bf7108774c34016a2193e2cd7639b7d", "value": " 4.92G/4.92G [00:16<00:00, 297MB/s]" } }, "77606cd2fe1b4d33a91ede944bb1dec0": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7989003a613e45f780d3f800e121543a": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, 
"grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7c6658cfff1a4d27af3de148184f77d9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "80dfd3e80ceb444a83ec1fd65f9af80e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "81458e7953a349cfafccaa213b370406": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "84c34bfecda64391a609e19f131d51d4": { 
"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8659c3eddb014c3bb5931fd9e6fadad8": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_00d3286c9c1d4161bb777b7b65ae744d", "placeholder": "​", "style": "IPY_MODEL_66f27fb11edf453b8144c2dfcdc66baa", "value": "model-00002-of-00004.safetensors: 100%" } }, "87da9905a0534c26ad0712ad426ca930": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": 
"LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "87f474861e54432e9d533e0a89bb77da": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, 
"visibility": null, "width": null } }, "8ae98969541849efa356cf912ac39b1e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_f9373112649945e3b446c3e1ec274dc1", "IPY_MODEL_d49791082a304ade95c185c79fae1f41", "IPY_MODEL_616e383bb3d442bcb6edb2721a8180b6" ], "layout": "IPY_MODEL_87f474861e54432e9d533e0a89bb77da" } }, "8ba9f009e92a46fcbcbb401dc444f12e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "94073be250cd42d5b82e196e30cbf22e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": 
"@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9492edc02dee456f840325d913fa4e4f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, 
"overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "97e8877869cd4be68ff38ce745be5045": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "98b4680141ee423bb5e43c47613d8440": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b02ffefca3f34252914e76f4a8a467dc", "IPY_MODEL_31d27bf34a74432f8e0dbfe9ecb76130", "IPY_MODEL_a3137f3669b54e84be91010c9654d985" ], "layout": "IPY_MODEL_5a2886564d3f40ceaa30b743dbe81f45" } }, "9ba6a11ffd194bf9a0900f52a7ed4d4f": { "model_module": "@jupyter-widgets/controls", 
"model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_97e8877869cd4be68ff38ce745be5045", "placeholder": "​", "style": "IPY_MODEL_cc3da88e93c4499993b7bbb7d3064326", "value": "model-00001-of-00004.safetensors: 100%" } }, "9bb60a5a3710463ebe3a17f8d2a446be": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_0a08fb81165748748ccb080e6df0600f", "IPY_MODEL_603690f543114a7fb6aebd433c80bdc3", "IPY_MODEL_773b802daed942f5a11f3eab3b83be08" ], "layout": "IPY_MODEL_7989003a613e45f780d3f800e121543a" } }, "9d49589118f5432cac49650251046429": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": 
null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a06dcb3bdfc84905a7222066c32fe500": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_fea1e2327d2944859af3d91c216b9008", "placeholder": "​", "style": "IPY_MODEL_320c00a5d18c45ccae634d166f1bd810", "value": "model-00004-of-00004.safetensors: 100%" } }, "a06fbde549cc47fdaddfbdb82d35d823": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "a3137f3669b54e84be91010c9654d985": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_6176990205cc499f8995c71fc6b9d4df", "placeholder": "​", "style": 
"IPY_MODEL_66c23ae98bcc45f18fc5c91e0e73c3e4", "value": " 4.92G/4.92G [00:16<00:00, 297MB/s]" } }, "a3dc9dfadae642b4a873705596739468": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "a63351a6715643378491ba831b3fb05d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ac1e34f4bd6c420bb6cc2fdde5f3ed4d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "acae8bbbb4a84ed49be72fecd11fb052": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0d51fdc2c416474da04079db6579890f", "max": 4976698672, 
"min": 0, "orientation": "horizontal", "style": "IPY_MODEL_c4598300a77b4667b1117f9499f5ccb7", "value": 4976698672 } }, "af985cf6fa26475eb2c4dd81e0c79ff4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_8659c3eddb014c3bb5931fd9e6fadad8", "IPY_MODEL_f5fa00d96c4c49e48e1806d23a5b8570", "IPY_MODEL_080c484114f64f5591fa1287a35b46c9" ], "layout": "IPY_MODEL_14dc6a3717484c55a116612e28447dbb" } }, "b02ffefca3f34252914e76f4a8a467dc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_15ea8fcfe097471e8fc9502a162f5904", "placeholder": "​", "style": "IPY_MODEL_c779e80c50ba4434bfa1d326c5cc9b0f", "value": "model-00003-of-00004.safetensors: 100%" } }, "b46a08cf4929422eb0f76d8d9af11249": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_1da83719e47c4196b06f3aa32056b560", "max": 1168138808, "min": 0, "orientation": 
"horizontal", "style": "IPY_MODEL_c4a2c88326d14fbca87cfde073755a2e", "value": 1168138808 } }, "b8a98f163ebd4ac89af08a49c0881c23": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b953419300604b8e86fc0ad003fdfd2f", "placeholder": "​", "style": "IPY_MODEL_f1865ed0fbcc40eeabdca90a43d00069", "value": "model-00003-of-00004.safetensors: 100%" } }, "b953419300604b8e86fc0ad003fdfd2f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bdf8b693821344fc97918e6cbc31c8bf": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", 
"model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "be3e9bf271f04eb0b119659e1af3a0ea": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c4598300a77b4667b1117f9499f5ccb7": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "c4a2c88326d14fbca87cfde073755a2e": { "model_module": "@jupyter-widgets/controls", 
"model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "c52e0f34892b4daa84c1bf61500ac399": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c779e80c50ba4434bfa1d326c5cc9b0f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ca81071ab07446df96795a482ce0c630": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, 
"grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "cc20ffcf0c1a4656945959bf457dfd84": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "cc3da88e93c4499993b7bbb7d3064326": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "d160986df978416c9ad91d1e10fc90fc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "d49791082a304ade95c185c79fae1f41": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", 
"_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_00148825ce0248b7a23eb28e3eca6749", "max": 4999802720, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f1a9b0c2431640298a6c1b258298b12d", "value": 4999802720 } }, "d598f094c3ce4daeab19fac8094cba7e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_0afc2d23514b45c9890b5d2ee4e6fa0b", "IPY_MODEL_3da5d38bf3314d3eaa7cedebae41c076", "IPY_MODEL_55e6b727a4594078beb3853cc1891308" ], "layout": "IPY_MODEL_f17fa78263414ef8b414c7bf3ac03192" } }, "d5b4bb4891ec4e44be46e9815c7e10dc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f55b59efcefa4ad5955d082f4bf7c637", "max": 4976698672, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_1b02e0c7d1604b1c87a327c4c4f8b0e7", "value": 4976698672 } }, "d74186bb74d142dfb683fa347b6990f7": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "dc5d555099f64a998514ebde90eeb6df": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_3d519ce3562c4e249bf392c7f43d04c0", "placeholder": "​", "style": "IPY_MODEL_cc20ffcf0c1a4656945959bf457dfd84", "value": " 1.17G/1.17G [00:03<00:00, 328MB/s]" } }, "dffa208978f34e6a9aae94ecda92fe67": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b8a98f163ebd4ac89af08a49c0881c23", "IPY_MODEL_f0d9febe1a634a0ba7e8e50fa104dcc2", "IPY_MODEL_e23870f0c7ff40cc8fa6a1e862a4af99" ], "layout": "IPY_MODEL_87da9905a0534c26ad0712ad426ca930" } }, "e0550cab24c7492787af40dc4b8576bf": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "e1e4125eac004bae92dc1f22f673bf0e": { 
"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_81458e7953a349cfafccaa213b370406", "placeholder": "​", "style": "IPY_MODEL_a3dc9dfadae642b4a873705596739468", "value": "model-00001-of-00004.safetensors: 100%" } }, "e23870f0c7ff40cc8fa6a1e862a4af99": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5e97f7c2e8f5453dafcdad0552060e60", "placeholder": "​", "style": "IPY_MODEL_4b3e7b8774df4b458bb6c6146fe3226d", "value": " 4.92G/4.92G [00:20<00:00, 317MB/s]" } }, "e7602abc26714ee890a0cf5c0c7b67e1": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_6c857e69d5204cd3b7c3bf426993ad1f", "max": 1168138808, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_2145e47428f1446fba3e62b3cde0a7f5", "value": 1168138808 } }, "e805bb6dfee34dab8870f4618d8bffdb": { "model_module": 
"@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e8a4b441281b4038bb0204d093411f68": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_77606cd2fe1b4d33a91ede944bb1dec0", "placeholder": "​", "style": "IPY_MODEL_f1ba439c26d64c90af2f162c74348405", "value": " 4.98G/4.98G [00:16<00:00, 296MB/s]" } }, "e8b187b40ec14db3af17a380830a35bf": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": 
null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e94ca32eaa9f4714a3b05a5fdf24d02b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e9aba3d53b4d45c485a7aad649c7b465": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7c6658cfff1a4d27af3de148184f77d9", "max": 4976698672, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7266a729edfb4a44b5b1c67dc79be146", "value": 4976698672 } }, 
"ea0128909a9d4801ba312a876b0cf183": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "eb94612785e64552aea8674dc8647a93": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, 
"max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ebe04aeaaac042aaaa0885992e45793d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ed28e180d94a4b7aa548581612e31232": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_ff4338faded5494da1ccb660e1c441ed", "IPY_MODEL_b46a08cf4929422eb0f76d8d9af11249", "IPY_MODEL_f049eb4a50f54c34912ca959d2eaf353" ], "layout": "IPY_MODEL_80dfd3e80ceb444a83ec1fd65f9af80e" } }, "ef93a2f58cc54373941f43658bb808cf": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, 
"grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f049eb4a50f54c34912ca959d2eaf353": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f0ab5a46cbb0444c88ed137d8a95002b", "placeholder": "​", "style": "IPY_MODEL_f8f28ac0e149428f9fef42373c6a87d0", "value": " 1.17G/1.17G [00:03<00:00, 307MB/s]" } }, "f0ab5a46cbb0444c88ed137d8a95002b": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, 
"max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f0d9febe1a634a0ba7e8e50fa104dcc2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ea0128909a9d4801ba312a876b0cf183", "max": 4915916176, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d160986df978416c9ad91d1e10fc90fc", "value": 4915916176 } }, "f114549fe8ce49638a791ca2fecb2d89": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f17fa78263414ef8b414c7bf3ac03192": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, 
"grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f1865ed0fbcc40eeabdca90a43d00069": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f1a12d7929db4309b9881853135359fc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_76dbab4873f342019c5d7624ae2c9775", "placeholder": "​", "style": "IPY_MODEL_3cea4b431147441a8d9bd872811d5974", "value": " 4.98G/4.98G [00:16<00:00, 309MB/s]" } }, "f1a9b0c2431640298a6c1b258298b12d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "f1ba439c26d64c90af2f162c74348405": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f3788acce34f4956b0727b58d0cf38c6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_6022a9426683420690d9b41a0ca4f870", "IPY_MODEL_e9aba3d53b4d45c485a7aad649c7b465", "IPY_MODEL_f1a12d7929db4309b9881853135359fc" ], "layout": "IPY_MODEL_58c9dec75a3346b1b787f88dd510d254" } }, "f55b59efcefa4ad5955d082f4bf7c637": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, 
"justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f5fa00d96c4c49e48e1806d23a5b8570": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5798e5118430439fb1f6bf29e1bafe58", "max": 4999802720, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_357f367cf74146b8825be371acd51d06", "value": 4999802720 } }, "f8f28ac0e149428f9fef42373c6a87d0": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f9373112649945e3b446c3e1ec274dc1": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e805bb6dfee34dab8870f4618d8bffdb", "placeholder": "​", 
"style": "IPY_MODEL_be3e9bf271f04eb0b119659e1af3a0ea", "value": "model-00002-of-00004.safetensors: 100%" } }, "fea1e2327d2944859af3d91c216b9008": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ff4338faded5494da1ccb660e1c441ed": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_519147a10b984befbd0f255f78c1f66a", "placeholder": "​", "style": "IPY_MODEL_562e82438dbe41b793ff488b8447c5bf", "value": "model-00004-of-00004.safetensors: 100%" } } } } }, "nbformat": 4, "nbformat_minor": 5 }