{
"cells": [
{
"cell_type": "markdown",
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
"metadata": {
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
"metadata": {
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
},
"source": [
"# Llama 3.2 From Scratch (A Standalone Notebook)"
]
},
{
"cell_type": "markdown",
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
"metadata": {
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
},
"source": [
"- This notebook is purposefully minimal and focuses on the code to implement the Llama 3.2 1B and 3B LLMs\n",
"- For a step-by-step guide that explains the individual components and the relationship between GPT, Llama 2, and Llama 3, please see the following companion notebooks:\n",
" - [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n",
" - [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n",
" \n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/llama32.webp\" width=\"700px\">\n",
" \n",
" \n",
"- About the code:\n",
" - all code is my own, mapping the Llama 3 architecture onto the model code implemented in my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under the permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))\n",
" - the tokenizer code is inspired by the original [Llama 3 tokenizer code](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), which Meta AI used to extend the tiktoken GPT-4 tokenizer\n",
" - the RoPE rescaling section is inspired by the [_compute_llama3_parameters function](https://github.com/huggingface/transformers/blob/5c1027bf09717f664b579e01cbb8ec3ef5aeb140/src/transformers/modeling_rope_utils.py#L329-L347) in the `transformers` library"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c201adb-747e-437b-9a62-442802941e01",
"metadata": {},
"outputs": [],
"source": [
"# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"blobfile version: 3.0.0\n",
"huggingface_hub version: 0.25.2\n",
"tiktoken version: 0.8.0\n",
"torch version: 2.5.0\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\n",
" \"blobfile\", # to download pretrained weights\n",
" \"huggingface_hub\", # to download pretrained weights\n",
" \"tiktoken\", # to implement the tokenizer\n",
" \"torch\", # to implement the model\n",
"]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
"metadata": {
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
},
"source": [
" \n",
"# 1. Architecture code"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "82076c21-9331-4dcd-b017-42b046cf1a60",
"metadata": {
"id": "82076c21-9331-4dcd-b017-42b046cf1a60"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class FeedForward(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
" self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
" self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
"\n",
" def forward(self, x):\n",
" x_fc1 = self.fc1(x)\n",
" x_fc2 = self.fc2(x)\n",
" x = nn.functional.silu(x_fc1) * x_fc2\n",
" return self.fc3(x)"
]
},
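- The `FeedForward` module above is a SwiGLU-style block. `fc1` acts as the gate, `fc2` as the up-projection, and `fc3` projects back to `emb_dim`. The minimal sketch below uses toy sizes chosen only for illustration, not the Llama 3.2 values. It checks the intermediate shapes and confirms that `silu(z)` is the same as `z * sigmoid(z)`:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
emb_dim, hidden_dim = 8, 32  # toy sizes for illustration only

# Bias-free projections mirroring fc1 (gate), fc2 (up), and fc3 (down) above
fc1 = nn.Linear(emb_dim, hidden_dim, bias=False)
fc2 = nn.Linear(emb_dim, hidden_dim, bias=False)
fc3 = nn.Linear(hidden_dim, emb_dim, bias=False)

x = torch.randn(2, 5, emb_dim)  # (batch_size, num_tokens, emb_dim)

gated = nn.functional.silu(fc1(x)) * fc2(x)  # (2, 5, hidden_dim)
out = fc3(gated)                              # (2, 5, emb_dim)

# SiLU(z) = z * sigmoid(z), spelled out for comparison
z = fc1(x)
manual = fc3(z * torch.sigmoid(z) * fc2(x))

print(gated.shape, out.shape)  # torch.Size([2, 5, 32]) torch.Size([2, 5, 8])
print(torch.allclose(out, manual, atol=1e-5))  # True
```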
{
"cell_type": "code",
"execution_count": 5,
"id": "4b9a346f-5826-4083-9162-abd56afc03f0",
"metadata": {
"id": "4b9a346f-5826-4083-9162-abd56afc03f0"
},
"outputs": [],
"source": [
"def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n",
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
"\n",
" # Frequency adjustments\n",
" if freq_config is not None:\n",
" low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n",
" high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n",
"\n",
" wavelen = 2 * torch.pi / inv_freq\n",
"\n",
" inv_freq_llama = torch.where(\n",
" wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n",
" )\n",
"\n",
" smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n",
" freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n",
" )\n",
"\n",
" smoothed_inv_freq = (\n",
" (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n",
" )\n",
"\n",
" is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n",
" inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n",
" inv_freq = inv_freq_llama\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length)\n",
"\n",
" # Compute the angles\n",
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
"\n",
" # Precompute sine and cosine\n",
" cos = torch.cos(angles)\n",
" sin = torch.sin(angles)\n",
"\n",
" return cos, sin\n",
"\n",
"\n",
"def compute_rope(x, cos, sin):\n",
" # x: (batch_size, num_heads, seq_len, head_dim)\n",
" batch_size, num_heads, seq_len, head_dim = x.shape\n",
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
" # Split x into first half and second half\n",
" x1 = x[..., : head_dim // 2] # First half\n",
" x2 = x[..., head_dim // 2 :] # Second half\n",
"\n",
" # Adjust sin and cos shapes\n",
" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
"\n",
" # Apply the rotary transformation\n",
" rotated = torch.cat((-x2, x1), dim=-1)\n",
" x_rotated = (x * cos) + (rotated * sin)\n",
"\n",
" return x_rotated.to(dtype=x.dtype)"
]
},
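- `compute_rope` rotates each pair of dimensions `(i, i + head_dim // 2)` by an angle that depends on the token position. The minimal sketch below uses toy values and a hypothetical `rotate` helper that applies the same half-split layout to a single vector. It illustrates the two properties that make RoPE useful. First, the rotation preserves the vector norm. Second, the query-key dot product depends only on the relative offset between the two positions:

```python
import torch

torch.manual_seed(0)
head_dim, context_length, theta_base = 16, 32, 10_000  # toy values

# Inverse frequencies and angles, using the same half-split layout as above
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.arange(context_length)[:, None] * inv_freq[None, :]
angles = torch.cat([angles, angles], dim=1)  # (context_length, head_dim)
cos, sin = torch.cos(angles), torch.sin(angles)

def rotate(vec, pos):
    # Hypothetical helper: rotate one head_dim vector as if it sat at position `pos`
    x1, x2 = vec[: head_dim // 2], vec[head_dim // 2 :]
    rotated = torch.cat((-x2, x1))
    return vec * cos[pos] + rotated * sin[pos]

q = torch.randn(head_dim)
k = torch.randn(head_dim)

# Same offset (m - n = 3) at two different absolute positions
score_a = rotate(q, 5) @ rotate(k, 2)
score_b = rotate(q, 20) @ rotate(k, 17)

print(torch.allclose(score_a, score_b, atol=1e-4))  # True: only the offset matters
print(torch.allclose(rotate(q, 7).norm(), q.norm(), atol=1e-5))  # True: norm is preserved
```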
{
"cell_type": "code",
"execution_count": 6,
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
"metadata": {
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
},
"outputs": [],
"source": [
"class SharedBuffers:\n",
" _buffers = {}\n",
"\n",
" @staticmethod\n",
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
"\n",
" if key not in SharedBuffers._buffers:\n",
" # Create the buffers once and cache them for reuse\n",
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
" if dtype is not None:\n",
" cos = cos.to(dtype)\n",
" sin = sin.to(dtype)\n",
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
"\n",
" return SharedBuffers._buffers[key]\n",
"\n",
"\n",
"class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n",
" num_kv_groups,\n",
" rope_base=10_000,\n",
" rope_config=None,\n",
" dtype=None\n",
" ):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
" assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
" self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
" self.num_kv_groups = num_kv_groups\n",
" self.group_size = num_heads // num_kv_groups\n",
"\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
"\n",
" # Fetch buffers using SharedBuffers\n",
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
" self.register_buffer(\"mask\", mask)\n",
"\n",
" self.register_buffer(\"cos\", cos)\n",
" self.register_buffer(\"sin\", sin)\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
"\n",
" queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n",
" keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
" values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
"\n",
" # Reshape queries, keys, and values\n",
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
" keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
"\n",
" # Transpose keys, values, and queries\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
"\n",
" # Apply RoPE\n",
" keys = compute_rope(keys, self.cos, self.sin)\n",
" queries = compute_rope(queries, self.cos, self.sin)\n",
"\n",
" # Expand keys and values to match the number of heads\n",
" # Shape: (b, num_heads, num_tokens, head_dim)\n",
" keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" # For example, before repeat_interleave along dim=1 (key-value groups):\n",
" # [K1, K2]\n",
" # After repeat_interleave (each key-value group is repeated group_size times):\n",
" # [K1, K1, K2, K2]\n",
" # If we used regular repeat instead of repeat_interleave, we'd get:\n",
" # [K1, K2, K1, K2]\n",
"\n",
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
" # Shape: (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
"\n",
" # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
"\n",
" # Use the mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
"\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" assert keys.shape[-1] == self.head_dim\n",
"\n",
" # Shape: (b, num_tokens, num_heads, head_dim)\n",
" context_vec = (attn_weights @ values).transpose(1, 2)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n",
" context_vec = self.out_proj(context_vec) # optional projection\n",
"\n",
" return context_vec"
]
},
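- The toy sketch below illustrates two steps of `GroupedQueryAttention.forward`. It uses one token, `head_dim = 1`, and two key/value groups holding the values 1.0 (K1) and 2.0 (K2). With `repeat_interleave`, query head `h` is paired with key/value group `h // group_size`, whereas plain `repeat` would cycle through the groups. The second part builds the causal mask the same way `SharedBuffers` does and applies it to all-zero scores, so each row spreads its attention evenly over the current and earlier tokens only:

```python
import torch

# Shape: (batch_size=1, num_kv_groups=2, num_tokens=1, head_dim=1); K1 = 1.0, K2 = 2.0
keys = torch.tensor([1.0, 2.0]).view(1, 2, 1, 1)
group_size = 2  # e.g., num_heads = 4 query heads sharing num_kv_groups = 2

# repeat_interleave keeps copies of the same group adjacent: K1, K1, K2, K2
interleaved = keys.repeat_interleave(group_size, dim=1)
# repeat tiles the whole sequence instead: K1, K2, K1, K2
tiled = keys.repeat(1, group_size, 1, 1)

print(interleaved.flatten().tolist())  # [1.0, 1.0, 2.0, 2.0]
print(tiled.flatten().tolist())        # [1.0, 2.0, 1.0, 2.0]

# Causal mask as built in SharedBuffers: ones above the diagonal mark future tokens
mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()
scores = torch.zeros(4, 4).masked_fill(mask, -torch.inf)
weights = torch.softmax(scores, dim=-1)
print(weights[1].tolist())  # [0.5, 0.5, 0.0, 0.0]: token 1 sees tokens 0 and 1 only
```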
{
"cell_type": "code",
"execution_count": 7,
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
"metadata": {
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
},
"outputs": [],
"source": [
"class TransformerBlock(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.att = GroupedQueryAttention(\n",
" d_in=cfg[\"emb_dim\"],\n",
" d_out=cfg[\"emb_dim\"],\n",
" context_length=cfg[\"context_length\"],\n",
" num_heads=cfg[\"n_heads\"],\n",
" num_kv_groups=cfg[\"n_kv_groups\"],\n",
" rope_base=cfg[\"rope_base\"],\n",
" rope_config=cfg[\"rope_freq\"],\n",
" dtype=cfg[\"dtype\"]\n",
" )\n",
" self.ff = FeedForward(cfg)\n",
" self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
" self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
" self.dtype = cfg[\"dtype\"] # Cast to the configured dtype instead of hardcoding bfloat16\n",
"\n",
" def forward(self, x):\n",
" # Shortcut connection for attention block\n",
" shortcut = x\n",
" x = self.norm1(x)\n",
" x = self.att(x.to(self.dtype)) # Shape [batch_size, num_tokens, emb_size]\n",
" x = x + shortcut # Add the original input back\n",
"\n",
" # Shortcut connection for feed-forward block\n",
" shortcut = x\n",
" x = self.norm2(x)\n",
" x = self.ff(x.to(self.dtype))\n",
" x = x + shortcut # Add the original input back\n",
"\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
"metadata": {
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
},
"outputs": [],
"source": [
"class Llama3Model(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
"\n",
" self.trf_blocks = nn.Sequential(\n",
" *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
"\n",
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" self.dtype = cfg[\"dtype\"] # Cast to the configured dtype instead of hardcoding bfloat16\n",
"\n",
" def forward(self, in_idx):\n",
" tok_embeds = self.tok_emb(in_idx)\n",
" x = tok_embeds\n",
" x = self.trf_blocks(x)\n",
" x = self.final_norm(x)\n",
" logits = self.out_head(x.to(self.dtype))\n",
" return logits"
]
},
{
"cell_type": "markdown",
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
"metadata": {
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
},
"source": [
" \n",
"# 2. Initialize model"
]
},
{
"cell_type": "markdown",
"id": "23dea40c-fe20-4a75-be25-d6fce5863c01",
"metadata": {
"id": "23dea40c-fe20-4a75-be25-d6fce5863c01"
},
"source": [
"- The remainder of this notebook uses the Llama 3.2 1B model; to use the 3B model variant, uncomment the second configuration dictionary in the following code cell"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
},
"outputs": [],
"source": [
"# Llama 3.2 1B\n",
"\n",
"LLAMA32_CONFIG = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # Context length\n",
" \"emb_dim\": 2048, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # Number of layers\n",
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0,\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
" \"original_context_length\": 8192,\n",
" }\n",
"}\n",
"\n",
"# Llama 3.2 3B\n",
"\n",
"# LLAMA32_CONFIG = {\n",
"# \"vocab_size\": 128_256, # Vocabulary size\n",
"# \"context_length\": 131_072, # Context length\n",
"# \"emb_dim\": 3072, # Embedding dimension\n",
"# \"n_heads\": 24, # Number of attention heads\n",
"# \"n_layers\": 28, # Number of layers\n",
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
"# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
"# \"rope_freq\": { # RoPE frequency scaling\n",
"# \"factor\": 32.0,\n",
"# \"low_freq_factor\": 1.0,\n",
"# \"high_freq_factor\": 4.0,\n",
"# \"original_context_length\": 8192,\n",
"# }\n",
"# }\n",
"\n",
"LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\""
]
},
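- The `rope_freq` settings above drive the frequency adjustment in `precompute_rope_params`. The sketch below copies the 1B values, with `head_dim = 2048 // 32 = 64`, and reimplements the same rule standalone. Components whose wavelength is shorter than `original_context_length / high_freq_factor` are kept unchanged. Components whose wavelength is longer than `original_context_length / low_freq_factor` are divided by `factor`. The components in between are interpolated smoothly:

```python
import torch

# Values from the Llama 3.2 1B config above; head_dim = emb_dim // n_heads = 2048 // 32
head_dim, theta_base = 64, 500_000.0
factor, low_freq_factor, high_freq_factor, original_context_length = 32.0, 1.0, 4.0, 8192

inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
wavelen = 2 * torch.pi / inv_freq

low_freq_wavelen = original_context_length / low_freq_factor    # 8192.0
high_freq_wavelen = original_context_length / high_freq_factor  # 2048.0

is_high = wavelen < high_freq_wavelen  # short wavelengths: kept as-is
is_low = wavelen > low_freq_wavelen    # long wavelengths: divided by `factor`
is_medium = ~(is_high | is_low)        # in between: interpolated

# Same interpolation as in precompute_rope_params
smooth = (original_context_length / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed = (1 - smooth) * (inv_freq / factor) + smooth * inv_freq
scaled = torch.where(is_low, inv_freq / factor, inv_freq)
scaled = torch.where(is_medium, smoothed, scaled)

print(int(is_high.sum()), int(is_medium.sum()), int(is_low.sum()))  # 15 3 14
```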
{
"cell_type": "markdown",
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06",
"metadata": {
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06"
},
"source": [
"- Reduce the context length so that the model fits into the memory of a MacBook Air, and rescale the RoPE base (theta) by the same ratio (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c",
"metadata": {
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New RoPE theta: 31250.0\n"
]
}
],
"source": [
"old_context_length = LLAMA32_CONFIG[\"context_length\"]\n",
"LLAMA32_CONFIG[\"context_length\"] = 8192\n",
"\n",
"\n",
"def rescale_theta(theta_old, context_length_old, context_length_new):\n",
" scaling_factor = context_length_new / context_length_old\n",
" theta_new = theta_old * scaling_factor\n",
" return theta_new\n",
"\n",
"LLAMA32_CONFIG[\"rope_base\"] = rescale_theta(\n",
" LLAMA32_CONFIG[\"rope_base\"],\n",
" old_context_length,\n",
" LLAMA32_CONFIG[\"context_length\"]\n",
")\n",
"\n",
"print(\"New RoPE theta:\", LLAMA32_CONFIG[\"rope_base\"])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
},
"outputs": [],
"source": [
"model = Llama3Model(LLAMA32_CONFIG)"
]
},
{
"cell_type": "markdown",
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1",
"metadata": {
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1"
},
"source": [
"- The following should print True three times, confirming that the mask, cos, and sin buffers are shared across all transformer blocks instead of being (wastefully) recreated:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
"outputId": "8efc4937-e616-40d0-cd59-670d7eb3e841"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"# Check buffers\n",
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"outputId": "00d7e983-262e-4c65-f322-f4d999311988"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of parameters: 1,498,482,688\n",
"\n",
"Total number of unique parameters: 1,235,814,400\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"\n",
"# Account for weight tying (Llama 3.2 reuses the token embedding weights in the output head)\n",
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
]
},
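- The two printed totals can be reproduced by hand from the 1B configuration. In the sketch below, the hyperparameters are copied from `LLAMA32_CONFIG`, and each term corresponds to one layer type in the model code above. Note that grouped-query attention shrinks `W_key` and `W_value` to `n_kv_groups * head_dim` output features, and that the feed-forward layers account for most of each block's parameters:

```python
# Llama 3.2 1B hyperparameters copied from LLAMA32_CONFIG above
vocab_size, emb_dim, n_heads, n_layers = 128_256, 2048, 32, 16
hidden_dim, n_kv_groups = 8192, 8
head_dim = emb_dim // n_heads  # 64

embedding = vocab_size * emb_dim                 # tok_emb
attention = (
    emb_dim * emb_dim                            # W_query
    + 2 * emb_dim * (n_kv_groups * head_dim)     # W_key and W_value (only 8 KV groups)
    + emb_dim * emb_dim                          # out_proj
)
feed_forward = 3 * emb_dim * hidden_dim          # fc1, fc2, fc3
norms = 2 * emb_dim                              # norm1 and norm2 weights
per_layer = attention + feed_forward + norms

# Embedding + all blocks + final_norm + out_head
total = embedding + n_layers * per_layer + emb_dim + vocab_size * emb_dim
unique = total - embedding  # count the tied embedding/output matrix only once

print(f"{total:,}")   # 1,498,482,688
print(f"{unique:,}")  # 1,235,814,400
```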
{
"cell_type": "code",
"execution_count": 14,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32 (PyTorch default): 11.42 GB\n",
"bfloat16: 5.71 GB\n"
]
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
" # Calculate total number of elements per parameter\n",
" param_size = param.numel()\n",
" total_params += param_size\n",
" # Check if gradients are stored for this parameter\n",
" if param.requires_grad:\n",
" total_grads += param_size\n",
"\n",
" # Calculate buffer size (non-parameters that require memory)\n",
" total_buffers = sum(buf.numel() for buf in model.buffers())\n",
"\n",
" # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
" # We assume parameters and gradients are stored in the same type as input dtype\n",
" element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
" total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
"\n",
" # Convert bytes to gigabytes\n",
" total_memory_gb = total_memory_bytes / (1024**3)\n",
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
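- The memory estimate above can also be reproduced by hand. This sketch rests on three assumptions. First, it uses the parameter count printed earlier and the reduced context length of 8192. Second, it assumes every parameter requires a gradient, which holds for the freshly initialized model. Third, it assumes the shared mask, cos, and sin buffers are counted once, as `model.buffers()` does for tensors shared across blocks. Like `model_memory_size`, it applies a single element size to everything:

```python
total_params = 1_498_482_688         # parameter count printed above
context_length, head_dim = 8192, 64  # reduced context length; emb_dim // n_heads

# Buffers shared by all 16 blocks, counted once
mask = context_length * context_length   # causal mask
cos_sin = 2 * context_length * head_dim  # cos and sin tables
total_buffers = mask + cos_sin

def memory_gb(bytes_per_element):
    # parameters + gradients (same count as parameters) + buffers
    return (2 * total_params + total_buffers) * bytes_per_element / 1024**3

print(f"float32: {memory_gb(4):.2f} GB")   # float32: 11.42 GB
print(f"bfloat16: {memory_gb(2):.2f} GB")  # bfloat16: 5.71 GB
```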
{
"cell_type": "code",
"execution_count": 15,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
"\n",
"model.to(device);"
]
},
{
"cell_type": "markdown",
"id": "78e091e1-afa8-4d23-9aea-cced86181bfd",
"metadata": {
"id": "78e091e1-afa8-4d23-9aea-cced86181bfd"
},
"source": [
" \n",
"# 3. Load tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77",
"metadata": {
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77"
},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"import tiktoken\n",
"from tiktoken.load import load_tiktoken_bpe\n",
"\n",
"\n",
"class Tokenizer:\n",
" def __init__(self, model_path):\n",
" assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
" mergeable_ranks = load_tiktoken_bpe(model_path)\n",
"\n",
" self.special_tokens = {\n",
" \"<|begin_of_text|>\": 128000,\n",
" \"<|end_of_text|>\": 128001,\n",
" \"<|start_header_id|>\": 128006,\n",
" \"<|end_header_id|>\": 128007,\n",
" \"<|eot_id|>\": 128009,\n",
" }\n",
" self.special_tokens.update({\n",
" f\"<|reserved_{i}|>\": 128002 + i for i in range(254) if (128002 + i) not in self.special_tokens.values() # keeps all IDs within 128000-128255 (vocab_size = 128_256)\n",
" })\n",
"\n",
" self.model = tiktoken.Encoding(\n",
" name=Path(model_path).name,\n",
" pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
" mergeable_ranks=mergeable_ranks,\n",
" special_tokens=self.special_tokens\n",
" )\n",
"\n",
"\n",
" def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
" if bos:\n",
" tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
" else:\n",
" tokens = []\n",
"\n",
" tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
"\n",
" if eos:\n",
" tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
" return tokens\n",
"\n",
" def decode(self, tokens):\n",
" return self.model.decode(tokens)\n",
"\n",
"\n",
"class ChatFormat:\n",
" def __init__(self, tokenizer):\n",
" self.tokenizer = tokenizer\n",
"\n",
" def encode_header(self, message):\n",
" tokens = []\n",
" tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
" tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
" tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
" tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
" return tokens\n",
"\n",
" def encode(self, text):\n",
" message = {\n",
" \"role\": \"user\",\n",
" \"content\": text\n",
" }\n",
"\n",
" tokens = self.encode_header(message)\n",
" tokens.extend(\n",
" self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
" )\n",
" tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
" return tokens\n",
"\n",
" def decode(self, token_ids):\n",
" return self.tokenizer.decode(token_ids)"
]
},
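- `ChatFormat.encode` wraps a user message in the Llama 3 header and end-of-turn special tokens. It does not prepend `<|begin_of_text|>`, because it calls the tokenizer with `bos=False`. To make the template easy to inspect, the sketch below uses a hypothetical helper, `render_user_turn`, which is not part of the notebook or tiktoken. The helper builds the equivalent plain string instead of token IDs:

```python
# Hypothetical helper (not part of the notebook or tiktoken): builds the string
# layout that ChatFormat.encode tokenizes, so the template is easy to inspect
def render_user_turn(text):
    header = "<|start_header_id|>" + "user" + "<|end_header_id|>" + "\n\n"
    # ChatFormat strips surrounding whitespace from the message content
    return header + text.strip() + "<|eot_id|>"

prompt = render_user_turn("  What do llamas eat?  ")
print(repr(prompt))
# '<|start_header_id|>user<|end_header_id|>\n\nWhat do llamas eat?<|eot_id|>'
```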
{
"cell_type": "markdown",
"id": "b771b60c-c198-4b30-bf10-42031197ae86",
"metadata": {
"id": "b771b60c-c198-4b30-bf10-42031197ae86"
},
"source": [
"- Please note that Meta AI requires that you accept the Llama 3.2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository to accept the terms\n",
"- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n",
"\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/settings.webp?1\" width=\"300px\">\n",
"\n",
"- Then, create an access token and copy it so that you can paste it into the next code cell\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/access-token.webp?1\" width=\"600px\">"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
"outputId": "e6e6dc05-7330-45bc-a9a7-331919155bdd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
"Token is valid (permission: read).\n",
"Your token has been saved to /teamspace/studios/this_studio/.cache/huggingface/token\n",
"Login successful\n"
]
}
],
"source": [
"from huggingface_hub import login\n",
"\n",
"login()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "986bc1a0-804f-4154-80f8-44cefbee1368",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 141,
"referenced_widgets": [
"a1608feac06d4687967a3e398f01c489",
"518fb202e4b44aaba47f07d1a61b6762",
"672cdc5aea954de3af851c001a667ad3",
"eebf8874618746b39cf4a21a2728dc7f",
"5176834aa8784bba9ec21234b87a8948",
"e2dc407afcd945c798e30597fddfcb3c",
"0dccd57dcc5c43a588157cef957c07e8",
"33ca0cdf2c7f41598a381c4ebe6a4ee1",
"ee44487f58454dacb522b1e084ffb733",
"d2c41e71a3f441deaed091b620ac5603",
"3326b6141a1a4eba9f316df528a9b99a"
]
},
"id": "986bc1a0-804f-4154-80f8-44cefbee1368",
"outputId": "5dd7334b-4c71-465a-94d2-c3e95b9ddc58"
},
"outputs": [],
"source": [
"from huggingface_hub import hf_hub_download\n",
"\n",
"tokenizer_file_path = hf_hub_download(\n",
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
" filename=\"original/tokenizer.model\",\n",
" local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "_gBhxDtU_nxo",
"metadata": {
"id": "_gBhxDtU_nxo"
},
"outputs": [],
"source": [
"tokenizer = Tokenizer(tokenizer_file_path)\n",
"chat_tokenizer = ChatFormat(tokenizer)"
]
},
{
"cell_type": "markdown",
"id": "c172f89f-d301-439f-b809-46169e5f5945",
"metadata": {
"id": "c172f89f-d301-439f-b809-46169e5f5945"
},
"source": [
" \n",
"# 4. Load pretrained weights"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
},
"outputs": [],
"source": [
"def assign(left, right, tensor_name=\"unknown\"):\n",
" if left.shape != right.shape:\n",
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
"\n",
" if isinstance(right, torch.Tensor):\n",
" return torch.nn.Parameter(right.clone().detach())\n",
" else:\n",
" return torch.nn.Parameter(torch.tensor(right))\n",
"\n",
"\n",
"def load_weights_into_llama(model, param_config, params):\n",
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
"\n",
" for l in range(param_config[\"n_layers\"]):\n",
"\n",
" # Load attention weights\n",
" model.trf_blocks[l].att.W_query.weight = assign(\n",
" model.trf_blocks[l].att.W_query.weight,\n",
" params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].att.W_key.weight = assign(\n",
" model.trf_blocks[l].att.W_key.weight,\n",
" params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].att.W_value.weight = assign(\n",
" model.trf_blocks[l].att.W_value.weight,\n",
" params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].att.out_proj.weight = assign(\n",
" model.trf_blocks[l].att.out_proj.weight,\n",
" params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].norm1.weight = assign(\n",
" model.trf_blocks[l].norm1.weight,\n",
" params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
" f\"model.layers.{l}.input_layernorm.weight\"\n",
" )\n",
"\n",
" # Load FeedForward weights\n",
" model.trf_blocks[l].ff.fc1.weight = assign(\n",
" model.trf_blocks[l].ff.fc1.weight,\n",
" params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].ff.fc2.weight = assign(\n",
" model.trf_blocks[l].ff.fc2.weight,\n",
" params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.up_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].ff.fc3.weight = assign(\n",
" model.trf_blocks[l].ff.fc3.weight,\n",
" params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.down_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].norm2.weight = assign(\n",
" model.trf_blocks[l].norm2.weight,\n",
" params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
" f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
" )\n",
"\n",
" # Load output layer weights\n",
" model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
"\n",
" if \"lm_head.weight\" in params.keys():\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
" else:\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
" print(\"Model uses weight tying.\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17,
"referenced_widgets": [
"9881b6995c3f49dc89e6992fd9ab660b",
"17a3174e65c54476b2e0d1faf8f011ca",
"1bbf2e62c0754d1593beb4105a7f1ac1",
"b82112e1dec645d98aa1c1ba64abcb61",
"271e2bd6a35e4a8b92de8697f7c0be5f",
"90a79523187446dfa692723b2e5833a7",
"431ffb83b8c14bf182f0430e07ea6154",
"a8f1b72a33dd4b548de23fbd95e0da18",
"25cc36132d384189acfbecc59483134b",
"bfd06423ad544218968648016e731a46",
"d029630b63ff44cf807ade428d2eb421"
]
},
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model uses weight tying.\n"
]
}
],
"source": [
"from safetensors.torch import load_file\n",
"\n",
"\n",
"if LLAMA_SIZE_STR == \"1B\":\n",
" weights_file = hf_hub_download(\n",
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
" filename=\"model.safetensors\",\n",
" local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
" )\n",
" combined_weights = load_file(weights_file)\n",
"\n",
"\n",
"else:\n",
" combined_weights = {}\n",
" for i in range(1, 3):\n",
" weights_file = hf_hub_download(\n",
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
" filename=f\"model-0000{i}-of-00002.safetensors\",\n",
" local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
" )\n",
" current_weights = load_file(weights_file)\n",
" combined_weights.update(current_weights)\n",
"\n",
"\n",
"load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)\n",
"model.to(device)\n",
"del combined_weights # free up memory"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
"metadata": {
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Weight tying: True\n"
]
}
],
"source": [
"print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
]
},
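{
"cell_type": "markdown",
"id": "4c2e9a71-8b3d-4f6e-a1c5-7d0b2e9f3a18",
"metadata": {},
"source": [
"- Note that `torch.equal` only compares values: because `assign` calls `clone()`, the loaded `tok_emb` and `out_head` weights hold equal values in separate storage rather than sharing a single `Parameter`\n",
"- The toy cell below (hypothetical sizes, not the Llama model) sketches the difference between copying values and true weight sharing; assigning the same `Parameter` object to both modules (for example, `model.out_head.weight = model.tok_emb.weight`) would avoid storing the matrix twice, but this is an optional modification, not what the notebook does"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e7f3b26-1d4a-4c8b-b5e2-6a0c8f1d4e93",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Toy modules with hypothetical sizes (vocab_size=10, emb_dim=4), not the Llama model\n",
"emb = torch.nn.Embedding(10, 4)\n",
"head = torch.nn.Linear(4, 10, bias=False)  # weight shape (10, 4), same as emb.weight\n",
"\n",
"# Copying values (what `assign` does via clone): equal values, separate storage\n",
"head.weight = torch.nn.Parameter(emb.weight.clone().detach())\n",
"print(torch.equal(emb.weight, head.weight), emb.weight.data_ptr() == head.weight.data_ptr())  # True False\n",
"\n",
"# True weight tying: both modules reference the same Parameter object\n",
"head.weight = emb.weight\n",
"print(torch.equal(emb.weight, head.weight), emb.weight.data_ptr() == head.weight.data_ptr())  # True True"
]
},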
{
"cell_type": "markdown",
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
"metadata": {
"id": "57d07df1-4401-4792-b549-7c4cc5632323"
},
"source": [
" \n",
"# 5. Generate text"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
},
"outputs": [],
"source": [
"def text_to_token_ids(text, tokenizer):\n",
" encoded = tokenizer.encode(text)\n",
" encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension\n",
" return encoded_tensor\n",
"\n",
"\n",
"def token_ids_to_text(token_ids, tokenizer):\n",
" flat = token_ids.squeeze(0) # remove batch dimension\n",
" return tokenizer.decode(flat.tolist())\n",
"\n",
"\n",
"def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n",
"\n",
" # For-loop is the same as before: Get logits, and only focus on last time step\n",
" for _ in range(max_new_tokens):\n",
" idx_cond = idx[:, -context_size:]\n",
" with torch.no_grad():\n",
" logits = model(idx_cond)\n",
" logits = logits[:, -1, :]\n",
"\n",
" # New: Filter logits with top_k sampling\n",
" if top_k is not None:\n",
" # Keep only top_k values\n",
" top_logits, _ = torch.topk(logits, top_k)\n",
" min_val = top_logits[:, -1:] # keep shape (batch_size, 1) so the comparison broadcasts row-wise\n",
" logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
"\n",
" # New: Apply temperature scaling\n",
" if temperature > 0.0:\n",
" logits = logits / temperature\n",
"\n",
" # Apply softmax to get probabilities\n",
" probs = torch.softmax(logits, dim=-1) # (batch_size, vocab_size)\n",
"\n",
" # Sample from the distribution\n",
" idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)\n",
"\n",
" # Otherwise same as before: get idx of the vocab entry with the highest logits value\n",
" else:\n",
" idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)\n",
"\n",
" if eos_id is not None and torch.all(idx_next == eos_id): # Stop generating early if eos_id is specified and every sequence produced it\n",
" break\n",
"\n",
" # Same as before: append sampled index to the running sequence\n",
" idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)\n",
"\n",
" return idx"
]
},
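{
"cell_type": "markdown",
"id": "b61d8e4f-2a7c-4e93-8f05-3c9a1e7b2d60",
"metadata": {},
"source": [
"- To see what the top-k and temperature branches of `generate` do in isolation, the toy cell below applies them to a made-up logits vector (arbitrary values, independent of the model)\n",
"- With top_k=2, only the logits 3.0 and 2.0 survive; the probability of the larger one then rises from roughly 0.62 at temperature 2.0 to 0.73 at 1.0 and 0.88 at 0.5, which is why lower temperatures make sampling more deterministic"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3a9c5d2-7f1b-4a6e-9c84-0d2b6f8e1a37",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Made-up next-token logits for a batch of one and a 5-token vocabulary\n",
"logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 3.0]])\n",
"\n",
"# Top-k filtering: keep the k largest logits and mask the rest with -inf\n",
"top_k = 2\n",
"top_logits, _ = torch.topk(logits, top_k)\n",
"min_val = top_logits[:, -1:]  # shape (batch_size, 1)\n",
"filtered = torch.where(logits < min_val, torch.full_like(logits, float('-inf')), logits)\n",
"print(filtered)  # only 3.0 and 2.0 survive; the other entries become -inf\n",
"\n",
"# Temperature scaling: values < 1 sharpen the distribution, values > 1 flatten it\n",
"for temperature in (0.5, 1.0, 2.0):\n",
"    probs = torch.softmax(filtered / temperature, dim=-1)\n",
"    print(temperature, probs)"
]
},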
{
"cell_type": "code",
"execution_count": 24,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" Llamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:\n",
"\n",
"1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grassy meadows.\n",
"2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n",
"3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas. It is high in protein and fiber.\n",
"4. Other plants: Llamas will also eat other plants, such as wild grasses, shrubs, and trees.\n",
"\n",
"It's worth noting that the diet of llamas can vary depending on the region, climate,\n"
]
}
],
"source": [
"PROMPT = \"What do llamas eat?\"\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"token_ids = generate(\n",
" model=model,\n",
" idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n",
" max_new_tokens=150,\n",
" context_size=LLAMA32_CONFIG[\"context_length\"],\n",
" top_k=1,\n",
" temperature=0.\n",
")\n",
"\n",
"output_text = token_ids_to_text(token_ids, tokenizer)\n",
"\n",
"\n",
"def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n",
" # Find the index of the first occurrence of header_end (the assistant header)\n",
" index = text.find(header_end)\n",
"\n",
" if index != -1:\n",
" # Return the substring that follows header_end\n",
" return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace\n",
" else:\n",
" # If the header is not found, return the original text\n",
" return text\n",
"\n",
"print(\"Output text:\\n\", clean_text(output_text))"
]
},
{
"cell_type": "markdown",
"id": "549324d6-5c71-4147-ae21-2e67675faa3d",
"metadata": {
"id": "549324d6-5c71-4147-ae21-2e67675faa3d"
},
"source": [
" \n",
"# What's next?"
]
},
{
"cell_type": "markdown",
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
"metadata": {
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
},
"source": [
"- The notebook was kept purposefully minimal; if you are interested in additional explanations of the individual components, check out the following two companion notebooks:\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/gpt-and-all-llamas.webp\">\n",
"\n",
" 1. [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n",
" 2. [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n",
" \n",
"- For those interested in a comprehensive guide to building a large language model from scratch and a deeper understanding of its mechanics, you might like my book [Build a Large Language Model (From Scratch)](http://mng.bz/orYv).\n",
"\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"0dccd57dcc5c43a588157cef957c07e8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"17a3174e65c54476b2e0d1faf8f011ca": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_90a79523187446dfa692723b2e5833a7",
"placeholder": " ",
"style": "IPY_MODEL_431ffb83b8c14bf182f0430e07ea6154",
"tabbable": null,
"tooltip": null,
"value": "model.safetensors: 35%"
}
},
"1bbf2e62c0754d1593beb4105a7f1ac1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_a8f1b72a33dd4b548de23fbd95e0da18",
"max": 2471645608,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_25cc36132d384189acfbecc59483134b",
"tabbable": null,
"tooltip": null,
"value": 880803840
}
},
"25cc36132d384189acfbecc59483134b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"271e2bd6a35e4a8b92de8697f7c0be5f": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3326b6141a1a4eba9f316df528a9b99a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"33ca0cdf2c7f41598a381c4ebe6a4ee1": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"431ffb83b8c14bf182f0430e07ea6154": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"5176834aa8784bba9ec21234b87a8948": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"518fb202e4b44aaba47f07d1a61b6762": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_e2dc407afcd945c798e30597fddfcb3c",
"placeholder": " ",
"style": "IPY_MODEL_0dccd57dcc5c43a588157cef957c07e8",
"tabbable": null,
"tooltip": null,
"value": "tokenizer.model: 100%"
}
},
"672cdc5aea954de3af851c001a667ad3": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_33ca0cdf2c7f41598a381c4ebe6a4ee1",
"max": 2183982,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_ee44487f58454dacb522b1e084ffb733",
"tabbable": null,
"tooltip": null,
"value": 2183982
}
},
"90a79523187446dfa692723b2e5833a7": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9881b6995c3f49dc89e6992fd9ab660b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_17a3174e65c54476b2e0d1faf8f011ca",
"IPY_MODEL_1bbf2e62c0754d1593beb4105a7f1ac1",
"IPY_MODEL_b82112e1dec645d98aa1c1ba64abcb61"
],
"layout": "IPY_MODEL_271e2bd6a35e4a8b92de8697f7c0be5f",
"tabbable": null,
"tooltip": null
}
},
"a1608feac06d4687967a3e398f01c489": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_518fb202e4b44aaba47f07d1a61b6762",
"IPY_MODEL_672cdc5aea954de3af851c001a667ad3",
"IPY_MODEL_eebf8874618746b39cf4a21a2728dc7f"
],
"layout": "IPY_MODEL_5176834aa8784bba9ec21234b87a8948",
"tabbable": null,
"tooltip": null
}
},
"a8f1b72a33dd4b548de23fbd95e0da18": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b82112e1dec645d98aa1c1ba64abcb61": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_bfd06423ad544218968648016e731a46",
"placeholder": " ",
"style": "IPY_MODEL_d029630b63ff44cf807ade428d2eb421",
"tabbable": null,
"tooltip": null,
"value": " 870M/2.47G [00:20<00:37, 42.8MB/s]"
}
},
"bfd06423ad544218968648016e731a46": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d029630b63ff44cf807ade428d2eb421": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"d2c41e71a3f441deaed091b620ac5603": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e2dc407afcd945c798e30597fddfcb3c": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ee44487f58454dacb522b1e084ffb733": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"eebf8874618746b39cf4a21a2728dc7f": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_d2c41e71a3f441deaed091b620ac5603",
"placeholder": " ",
"style": "IPY_MODEL_3326b6141a1a4eba9f316df528a9b99a",
"tabbable": null,
"tooltip": null,
"value": " 2.18M/2.18M [00:00<00:00, 9.47MB/s]"
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}