{
"cells": [
{
"cell_type": "markdown",
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
"metadata": {
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
"metadata": {
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
},
"source": [
"# Llama 3.2 From Scratch (A Standalone Notebook)"
]
},
{
"cell_type": "markdown",
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
"metadata": {
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
},
"source": [
"- This notebook is purposefully minimal and focuses on the code to implement the Llama 3.2 1B and 3B LLMs\n",
"- For a step-by-step guide that explains the individual components and the relationship between GPT, Llama 2, and Llama 3, please see the following companion notebooks:\n",
" - [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n",
" - [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n",
" \n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/llama32.webp\" width=\"700px\">\n",
" \n",
" \n",
"- About the code:\n",
" - all code is my own code, mapping the Llama 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))\n",
" - the tokenizer code is inspired by the original [Llama 3 tokenizer code](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), which Meta AI used to extend the Tiktoken GPT-4 tokenizer\n",
" - the RoPE rescaling section is inspired by the [_compute_llama3_parameters function](https://github.com/huggingface/transformers/blob/5c1027bf09717f664b579e01cbb8ec3ef5aeb140/src/transformers/modeling_rope_utils.py#L329-L347) in the `transformers` library"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c201adb-747e-437b-9a62-442802941e01",
"metadata": {},
"outputs": [],
"source": [
"# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"blobfile version: 3.0.0\n",
"huggingface_hub version: 0.25.2\n",
"tiktoken version: 0.8.0\n",
"torch version: 2.5.0\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\n",
" \"blobfile\", # to download pretrained weights\n",
" \"huggingface_hub\", # to download pretrained weights\n",
" \"tiktoken\", # to implement the tokenizer\n",
" \"torch\", # to implement the model\n",
"]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
"metadata": {
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
},
"source": [
"&nbsp;\n",
"# 1. Architecture code"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "82076c21-9331-4dcd-b017-42b046cf1a60",
"metadata": {
"id": "82076c21-9331-4dcd-b017-42b046cf1a60"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class FeedForward(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
" self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
" self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
"\n",
" def forward(self, x):\n",
" x_fc1 = self.fc1(x)\n",
" x_fc2 = self.fc2(x)\n",
" x = nn.functional.silu(x_fc1) * x_fc2\n",
" return self.fc3(x)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4b9a346f-5826-4083-9162-abd56afc03f0",
"metadata": {
"id": "4b9a346f-5826-4083-9162-abd56afc03f0"
},
"outputs": [],
"source": [
"def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n",
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
"\n",
" # Frequency adjustments\n",
" if freq_config is not None:\n",
" low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n",
" high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n",
"\n",
" wavelen = 2 * torch.pi / inv_freq\n",
"\n",
" inv_freq_llama = torch.where(\n",
" wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n",
" )\n",
"\n",
" smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n",
" freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n",
" )\n",
"\n",
" smoothed_inv_freq = (\n",
" (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n",
" )\n",
"\n",
" is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n",
" inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n",
" inv_freq = inv_freq_llama\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length)\n",
"\n",
" # Compute the angles\n",
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
"\n",
" # Precompute sine and cosine\n",
" cos = torch.cos(angles)\n",
" sin = torch.sin(angles)\n",
"\n",
" return cos, sin\n",
"\n",
"\n",
"def compute_rope(x, cos, sin):\n",
" # x: (batch_size, num_heads, seq_len, head_dim)\n",
" batch_size, num_heads, seq_len, head_dim = x.shape\n",
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
" # Split x into first half and second half\n",
" x1 = x[..., : head_dim // 2] # First half\n",
" x2 = x[..., head_dim // 2 :] # Second half\n",
"\n",
" # Adjust sin and cos shapes\n",
" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
"\n",
" # Apply the rotary transformation\n",
" rotated = torch.cat((-x2, x1), dim=-1)\n",
" x_rotated = (x * cos) + (rotated * sin)\n",
"\n",
" return x_rotated.to(dtype=x.dtype)"
]
},
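{
"cell_type": "markdown",
"id": "rope-sanity-check-md",
"metadata": {},
"source": [
"- Optional sanity check: because RoPE applies a pure rotation, `compute_rope` should leave the tensor shape and the norm of each position's head vector unchanged (up to floating-point error); the minimal example below illustrates this with random values and arbitrary toy dimensions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "rope-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity check (toy dimensions chosen for illustration only):\n",
"# RoPE is a rotation, so shapes and per-vector norms should be preserved\n",
"batch_size, num_heads, seq_len, head_dim = 1, 2, 5, 16\n",
"\n",
"example_cos, example_sin = precompute_rope_params(head_dim=head_dim, context_length=seq_len)\n",
"example_queries = torch.randn(batch_size, num_heads, seq_len, head_dim)\n",
"rotated_queries = compute_rope(example_queries, example_cos, example_sin)\n",
"\n",
"print(rotated_queries.shape == example_queries.shape)\n",
"print(torch.allclose(rotated_queries.norm(dim=-1), example_queries.norm(dim=-1), atol=1e-5))"
]
},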
{
"cell_type": "code",
"execution_count": 6,
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
"metadata": {
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
},
"outputs": [],
"source": [
"class SharedBuffers:\n",
" _buffers = {}\n",
"\n",
" @staticmethod\n",
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
"\n",
" if key not in SharedBuffers._buffers:\n",
" # Create or fetch the buffers\n",
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
" if dtype is not None:\n",
" cos = cos.to(dtype)\n",
" sin = sin.to(dtype)\n",
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
"\n",
" return SharedBuffers._buffers[key]\n",
"\n",
"\n",
"class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n",
" num_kv_groups,\n",
" rope_base=10_000,\n",
" rope_config=None,\n",
" dtype=None\n",
" ):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
" assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
" self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
" self.num_kv_groups = num_kv_groups\n",
" self.group_size = num_heads // num_kv_groups\n",
"\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
"\n",
" # Fetch buffers using SharedBuffers\n",
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
" self.register_buffer(\"mask\", mask, persistent=False)\n",
"\n",
" self.register_buffer(\"cos\", cos, persistent=False)\n",
" self.register_buffer(\"sin\", sin, persistent=False)\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
"\n",
" queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n",
" keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
" values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
"\n",
" # Reshape queries, keys, and values\n",
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
" keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
"\n",
" # Transpose keys, values, and queries\n",
" keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
"\n",
" # Apply RoPE\n",
" keys = compute_rope(keys, self.cos, self.sin)\n",
" queries = compute_rope(queries, self.cos, self.sin)\n",
"\n",
" # Expand keys and values to match the number of heads\n",
" # Shape: (b, num_heads, num_tokens, head_dim)\n",
" keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" # For example, before repeat_interleave along dim=1 (query groups):\n",
" # [K1, K2]\n",
" # After repeat_interleave (each query group is repeated group_size times):\n",
" # [K1, K1, K2, K2]\n",
" # If we used regular repeat instead of repeat_interleave, we'd get:\n",
" # [K1, K2, K1, K2]\n",
"\n",
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
" # Shape: (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
"\n",
" # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
"\n",
" # Use the mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
"\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" assert keys.shape[-1] == self.head_dim\n",
"\n",
" # Shape: (b, num_tokens, num_heads, head_dim)\n",
" context_vec = (attn_weights @ values).transpose(1, 2)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n",
" context_vec = self.out_proj(context_vec) # optional projection\n",
"\n",
" return context_vec"
]
},
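{
"cell_type": "markdown",
"id": "gqa-shape-check-md",
"metadata": {},
"source": [
"- Optional sanity check: in grouped-query attention, the keys and values are projected into `num_kv_groups` heads and then expanded to `num_heads` via `repeat_interleave`, so the output should retain the input's batch and sequence dimensions with `d_out` channels; the minimal example below illustrates this with arbitrary toy dimensions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "gqa-shape-check",
"metadata": {},
"outputs": [],
"source": [
"# Minimal shape check with arbitrary toy dimensions (for illustration only)\n",
"toy_gqa = GroupedQueryAttention(\n",
"    d_in=32, d_out=32, context_length=16,\n",
"    num_heads=4, num_kv_groups=2\n",
")\n",
"example_input = torch.randn(2, 8, 32)  # (batch_size, num_tokens, d_in)\n",
"print(toy_gqa(example_input).shape)    # expected: torch.Size([2, 8, 32])"
]
},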
{
"cell_type": "code",
"execution_count": 7,
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
"metadata": {
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
},
"outputs": [],
"source": [
"class TransformerBlock(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.att = GroupedQueryAttention(\n",
" d_in=cfg[\"emb_dim\"],\n",
" d_out=cfg[\"emb_dim\"],\n",
" context_length=cfg[\"context_length\"],\n",
" num_heads=cfg[\"n_heads\"],\n",
" num_kv_groups=cfg[\"n_kv_groups\"],\n",
" rope_base=cfg[\"rope_base\"],\n",
" rope_config=cfg[\"rope_freq\"],\n",
" dtype=cfg[\"dtype\"]\n",
" )\n",
" self.ff = FeedForward(cfg)\n",
" self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
" self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
"\n",
" def forward(self, x):\n",
" # Shortcut connection for attention block\n",
" shortcut = x\n",
" x = self.norm1(x)\n",
" x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n",
" x = x + shortcut # Add the original input back\n",
"\n",
" # Shortcut connection for feed-forward block\n",
" shortcut = x\n",
" x = self.norm2(x)\n",
" x = self.ff(x.to(torch.bfloat16))\n",
" x = x + shortcut # Add the original input back\n",
"\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
"metadata": {
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
},
"outputs": [],
"source": [
"class Llama3Model(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
"\n",
" self.trf_blocks = nn.Sequential(\n",
" *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
"\n",
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" def forward(self, in_idx):\n",
" tok_embeds = self.tok_emb(in_idx)\n",
" x = tok_embeds\n",
" x = self.trf_blocks(x)\n",
" x = self.final_norm(x)\n",
" logits = self.out_head(x.to(torch.bfloat16))\n",
" return logits"
]
},
{
"cell_type": "markdown",
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
"metadata": {
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
},
"source": [
"&nbsp;\n",
"# 2. Initialize model"
]
},
{
"cell_type": "markdown",
"id": "23dea40c-fe20-4a75-be25-d6fce5863c01",
"metadata": {
"id": "23dea40c-fe20-4a75-be25-d6fce5863c01"
},
"source": [
"- The remainder of this notebook uses the Llama 3.2 1B model; to use the 3B model variant, just uncomment the second configuration file in the following code cell"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
},
"outputs": [],
"source": [
"# Llama 3.2 1B\n",
"\n",
"LLAMA32_CONFIG = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # Context length\n",
" \"emb_dim\": 2048, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # Number of layers\n",
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0,\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
" \"original_context_length\": 8192,\n",
" }\n",
"}\n",
"\n",
"# Llama 3.2 3B\n",
"\n",
"# LLAMA32_CONFIG = {\n",
"# \"vocab_size\": 128_256, # Vocabulary size\n",
"# \"context_length\": 131_072, # Context length\n",
"# \"emb_dim\": 3072, # Embedding dimension\n",
"# \"n_heads\": 24, # Number of attention heads\n",
"# \"n_layers\": 28, # Number of layers\n",
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
"# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
"# \"rope_freq\": { # RoPE frequency scaling\n",
"# \"factor\": 32.0,\n",
"# \"low_freq_factor\": 1.0,\n",
"# \"high_freq_factor\": 4.0,\n",
"# \"original_context_length\": 8192,\n",
"# }\n",
"# }\n",
"\n",
"LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\""
]
},
{
"cell_type": "markdown",
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06",
"metadata": {
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06"
},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c",
"metadata": {
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New RoPE theta: 31250.0\n"
]
}
],
"source": [
"old_context_length = LLAMA32_CONFIG[\"context_length\"]\n",
"LLAMA32_CONFIG[\"context_length\"] = 8192\n",
"\n",
"\n",
"def rescale_theta(theta_old, context_length_old, context_length_new):\n",
" scaling_factor = context_length_new / context_length_old\n",
" theta_new = theta_old * scaling_factor\n",
" return theta_new\n",
"\n",
"LLAMA32_CONFIG[\"rope_base\"] = rescale_theta(\n",
" LLAMA32_CONFIG[\"rope_base\"],\n",
" old_context_length,\n",
" LLAMA32_CONFIG[\"context_length\"]\n",
")\n",
"\n",
"print(\"New RoPE theta:\", LLAMA32_CONFIG[\"rope_base\"])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
},
"outputs": [],
"source": [
"model = Llama3Model(LLAMA32_CONFIG)"
]
},
{
"cell_type": "markdown",
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1",
"metadata": {
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1"
},
"source": [
"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
"outputId": "8efc4937-e616-40d0-cd59-670d7eb3e841"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"# Check buffers\n",
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"outputId": "00d7e983-262e-4c65-f322-f4d999311988"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of parameters: 1,498,482,688\n",
"\n",
"Total number of unique parameters: 1,235,814,400\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"\n",
"# Account for weight tying\n",
"total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
"print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
]
},
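{
"cell_type": "markdown",
"id": "weight-tying-note",
"metadata": {},
"source": [
"- The difference between the two counts above corresponds to the token embedding matrix, which is counted only once because it is shared with the output layer via weight tying; for the 1B configuration this is 128,256 × 2,048 = 262,668,288 parameters"
]
},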
{
"cell_type": "code",
"execution_count": 14,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32 (PyTorch default): 11.42 GB\n",
"bfloat16: 5.71 GB\n"
]
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
" # Calculate total number of elements per parameter\n",
" param_size = param.numel()\n",
" total_params += param_size\n",
" # Check if gradients are stored for this parameter\n",
" if param.requires_grad:\n",
" total_grads += param_size\n",
"\n",
" # Calculate buffer size (non-parameters that require memory)\n",
" total_buffers = sum(buf.numel() for buf in model.buffers())\n",
"\n",
" # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
" # We assume parameters and gradients are stored in the same type as input dtype\n",
" element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
" total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
"\n",
" # Convert bytes to gigabytes\n",
" total_memory_gb = total_memory_bytes / (1024**3)\n",
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
"\n",
"model.to(device);"
]
},
{
"cell_type": "markdown",
"id": "78e091e1-afa8-4d23-9aea-cced86181bfd",
"metadata": {
"id": "78e091e1-afa8-4d23-9aea-cced86181bfd"
},
"source": [
"&nbsp;\n",
"# 3. Load tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77",
"metadata": {
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77"
},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"import tiktoken\n",
"from tiktoken.load import load_tiktoken_bpe\n",
"\n",
"\n",
"class Tokenizer:\n",
" def __init__(self, model_path):\n",
" assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
" mergeable_ranks = load_tiktoken_bpe(model_path)\n",
"\n",
" self.special_tokens = {\n",
" \"<|begin_of_text|>\": 128000,\n",
" \"<|end_of_text|>\": 128001,\n",
" \"<|start_header_id|>\": 128006,\n",
" \"<|end_header_id|>\": 128007,\n",
" \"<|eot_id|>\": 128009,\n",
" }\n",
" self.special_tokens.update({\n",
" f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n",
" })\n",
"\n",
" self.model = tiktoken.Encoding(\n",
" name=Path(model_path).name,\n",
" pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
" mergeable_ranks=mergeable_ranks,\n",
" special_tokens=self.special_tokens\n",
" )\n",
"\n",
"\n",
" def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
" if bos:\n",
" tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
" else:\n",
" tokens = []\n",
"\n",
" tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
"\n",
" if eos:\n",
" tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
" return tokens\n",
"\n",
" def decode(self, tokens):\n",
" return self.model.decode(tokens)\n",
"\n",
"\n",
"class ChatFormat:\n",
" def __init__(self, tokenizer):\n",
" self.tokenizer = tokenizer\n",
"\n",
" def encode_header(self, message):\n",
" tokens = []\n",
" tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
" tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
" tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
" tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
" return tokens\n",
"\n",
" def encode(self, text):\n",
" message = {\n",
" \"role\": \"user\",\n",
" \"content\": text\n",
" }\n",
"\n",
" tokens = self.encode_header(message)\n",
" tokens.extend(\n",
" self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
" )\n",
" tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
" return tokens\n",
"\n",
" def decode(self, token_ids):\n",
" return self.tokenizer.decode(token_ids)"
]
},
{
"cell_type": "markdown",
"id": "b771b60c-c198-4b30-bf10-42031197ae86",
"metadata": {
"id": "b771b60c-c198-4b30-bf10-42031197ae86"
},
"source": [
"- Please note that Meta AI requires that you accept the Llama 3.2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository to accept the terms\n",
"- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n",
"\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/settings.webp?1\" width=\"300px\">\n",
"\n",
"- Then, create and copy the access token so you can copy & paste it into the next code cell\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/access-token.webp?1\" width=\"600px\">"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
"outputId": "e6e6dc05-7330-45bc-a9a7-331919155bdd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
"Token is valid (permission: read).\n",
"Your token has been saved to /teamspace/studios/this_studio/.cache/huggingface/token\n",
"Login successful\n"
]
}
],
"source": [
"from huggingface_hub import login\n",
"\n",
"login()"
]
},
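{
"cell_type": "markdown",
"id": "hf-login-alt-md",
"metadata": {},
"source": [
"- Alternatively, if you prefer a non-interactive setup, you can pass your access token to `login()` directly (or set the `HF_TOKEN` environment variable before launching the notebook); the commented-out cell below is a minimal sketch where the placeholder must be replaced with your own token"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "hf-login-alt",
"metadata": {},
"outputs": [],
"source": [
"# Non-interactive alternative (replace the placeholder with your own access token):\n",
"# from huggingface_hub import login\n",
"# login(token=\"hf_...\")"
]
},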
{
"cell_type": "code",
"execution_count": 18,
"id": "986bc1a0-804f-4154-80f8-44cefbee1368",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 141,
"referenced_widgets": [
"a1608feac06d4687967a3e398f01c489",
"518fb202e4b44aaba47f07d1a61b6762",
"672cdc5aea954de3af851c001a667ad3",
"eebf8874618746b39cf4a21a2728dc7f",
"5176834aa8784bba9ec21234b87a8948",
"e2dc407afcd945c798e30597fddfcb3c",
"0dccd57dcc5c43a588157cef957c07e8",
"33ca0cdf2c7f41598a381c4ebe6a4ee1",
"ee44487f58454dacb522b1e084ffb733",
"d2c41e71a3f441deaed091b620ac5603",
"3326b6141a1a4eba9f316df528a9b99a"
]
},
"id": "986bc1a0-804f-4154-80f8-44cefbee1368",
"outputId": "5dd7334b-4c71-465a-94d2-c3e95b9ddc58"
},
"outputs": [],
"source": [
"from huggingface_hub import hf_hub_download\n",
"\n",
"tokenizer_file_path = hf_hub_download(\n",
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
" filename=\"original/tokenizer.model\",\n",
" local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "_gBhxDtU_nxo",
"metadata": {
"id": "_gBhxDtU_nxo"
},
"outputs": [],
"source": [
"tokenizer = Tokenizer(tokenizer_file_path)\n",
"chat_tokenizer = ChatFormat(tokenizer)"
]
},
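{
"cell_type": "markdown",
"id": "chat-format-demo-md",
"metadata": {},
"source": [
"- To see what the chat wrapper does, the optional cell below encodes a short user message with `chat_tokenizer` and decodes the token IDs again; the decoded string shows the message wrapped in the `<|start_header_id|>`, `<|end_header_id|>`, and `<|eot_id|>` special tokens of the Llama 3 chat template"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "chat-format-demo",
"metadata": {},
"outputs": [],
"source": [
"# Optional: inspect how ChatFormat wraps a user message in the chat template\n",
"example_ids = chat_tokenizer.encode(\"What do llamas eat?\")\n",
"print(example_ids)\n",
"print(chat_tokenizer.decode(example_ids))"
]
},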
{
"cell_type": "markdown",
"id": "c172f89f-d301-439f-b809-46169e5f5945",
"metadata": {
"id": "c172f89f-d301-439f-b809-46169e5f5945"
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
},
"outputs": [],
"source": [
"def assign(left, right, tensor_name=\"unknown\"):\n",
" if left.shape != right.shape:\n",
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
"\n",
" if isinstance(right, torch.Tensor):\n",
" return torch.nn.Parameter(right.clone().detach())\n",
" else:\n",
" return torch.nn.Parameter(torch.tensor(right))\n",
"\n",
"\n",
"def load_weights_into_llama(model, param_config, params):\n",
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
"\n",
" for l in range(param_config[\"n_layers\"]):\n",
"\n",
" # Load attention weights\n",
" model.trf_blocks[l].att.W_query.weight = assign(\n",
" model.trf_blocks[l].att.W_query.weight,\n",
" params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].att.W_key.weight = assign(\n",
" model.trf_blocks[l].att.W_key.weight,\n",
" params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].att.W_value.weight = assign(\n",
" model.trf_blocks[l].att.W_value.weight,\n",
" params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].att.out_proj.weight = assign(\n",
" model.trf_blocks[l].att.out_proj.weight,\n",
" params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].norm1.weight = assign(\n",
" model.trf_blocks[l].norm1.weight,\n",
" params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
" f\"model.layers.{l}.input_layernorm.weight\"\n",
" )\n",
"\n",
" # Load FeedForward weights\n",
" model.trf_blocks[l].ff.fc1.weight = assign(\n",
" model.trf_blocks[l].ff.fc1.weight,\n",
" params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].ff.fc2.weight = assign(\n",
" model.trf_blocks[l].ff.fc2.weight,\n",
" params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.up_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].ff.fc3.weight = assign(\n",
" model.trf_blocks[l].ff.fc3.weight,\n",
" params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.down_proj.weight\"\n",
" )\n",
" model.trf_blocks[l].norm2.weight = assign(\n",
" model.trf_blocks[l].norm2.weight,\n",
" params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
" f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
" )\n",
"\n",
" # Load output layer weights\n",
" model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
"\n",
" if \"lm_head.weight\" in params.keys():\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
" else:\n",
" model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
" print(\"Model uses weight tying.\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17,
"referenced_widgets": [
"9881b6995c3f49dc89e6992fd9ab660b",
"17a3174e65c54476b2e0d1faf8f011ca",
"1bbf2e62c0754d1593beb4105a7f1ac1",
"b82112e1dec645d98aa1c1ba64abcb61",
"271e2bd6a35e4a8b92de8697f7c0be5f",
"90a79523187446dfa692723b2e5833a7",
"431ffb83b8c14bf182f0430e07ea6154",
"a8f1b72a33dd4b548de23fbd95e0da18",
"25cc36132d384189acfbecc59483134b",
"bfd06423ad544218968648016e731a46",
"d029630b63ff44cf807ade428d2eb421"
]
},
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model uses weight tying.\n"
]
}
],
"source": [
"from safetensors.torch import load_file\n",
"\n",
"\n",
"if LLAMA_SIZE_STR == \"1B\":\n",
" weights_file = hf_hub_download(\n",
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
" filename=\"model.safetensors\",\n",
" local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
" )\n",
" combined_weights = load_file(weights_file)\n",
"\n",
"\n",
"else:\n",
" combined_weights = {}\n",
" for i in range(1, 3):\n",
" weights_file = hf_hub_download(\n",
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
" filename=f\"model-0000{i}-of-00002.safetensors\",\n",
" local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
" )\n",
" current_weights = load_file(weights_file)\n",
" combined_weights.update(current_weights)\n",
"\n",
"\n",
"load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)\n",
"model.to(device)\n",
"del combined_weights # free up memory"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
"metadata": {
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Weight tying: True\n"
]
}
],
"source": [
"print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
]
},
{
"cell_type": "markdown",
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
"metadata": {
"id": "57d07df1-4401-4792-b549-7c4cc5632323"
},
"source": [
"&nbsp;\n",
"# 5. Generate text"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
},
"outputs": [],
"source": [
"def text_to_token_ids(text, tokenizer):\n",
" encoded = tokenizer.encode(text)\n",
" encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension\n",
" return encoded_tensor\n",
"\n",
"\n",
"def token_ids_to_text(token_ids, tokenizer):\n",
" flat = token_ids.squeeze(0) # remove batch dimension\n",
" return tokenizer.decode(flat.tolist())\n",
"\n",
"\n",
"def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n",
"\n",
" # For-loop is the same as before: Get logits, and only focus on last time step\n",
" for _ in range(max_new_tokens):\n",
" idx_cond = idx[:, -context_size:]\n",
" with torch.no_grad():\n",
" logits = model(idx_cond)\n",
" logits = logits[:, -1, :]\n",
"\n",
" # New: Filter logits with top_k sampling\n",
" if top_k is not None:\n",
" # Keep only top_k values\n",
" top_logits, _ = torch.topk(logits, top_k)\n",
" min_val = top_logits[:, -1]\n",
" logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
"\n",
" # New: Apply temperature scaling\n",
" if temperature > 0.0:\n",
" logits = logits / temperature\n",
"\n",
" # Apply softmax to get probabilities\n",
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
"\n",
" # Sample from the distribution\n",
" idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)\n",
"\n",
" # Otherwise same as before: get idx of the vocab entry with the highest logits value\n",
" else:\n",
" idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)\n",
"\n",
" if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
" break\n",
"\n",
" # Same as before: append sampled index to the running sequence\n",
" idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)\n",
"\n",
" return idx"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" Llamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:\n",
"\n",
"1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grassy meadows.\n",
"2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n",
"3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas. It is high in protein and fiber.\n",
"4. Other plants: Llamas will also eat other plants, such as wild grasses, shrubs, and trees.\n",
"\n",
"It's worth noting that the diet of llamas can vary depending on the region, climate,\n"
]
}
],
"source": [
"PROMPT = \"What do llamas eat?\"\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"token_ids = generate(\n",
" model=model,\n",
" idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n",
" max_new_tokens=150,\n",
" context_size=LLAMA32_CONFIG[\"context_length\"],\n",
" top_k=1,\n",
" temperature=0.\n",
")\n",
"\n",
"output_text = token_ids_to_text(token_ids, tokenizer)\n",
"\n",
"\n",
"def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n",
" # Find the index of the first occurrence of \"<|end_header_id|>\"\n",
" index = text.find(header_end)\n",
"\n",
" if index != -1:\n",
" # Return the substring starting after \"<|end_header_id|>\"\n",
" return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace\n",
" else:\n",
" # If the token is not found, return the original text\n",
" return text\n",
"\n",
"print(\"Output text:\\n\", clean_text(output_text))"
]
},
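{
"cell_type": "markdown",
"id": "sampling-example-md",
"metadata": {},
"source": [
"- The settings above (`top_k=1`, `temperature=0.`) make the generation effectively greedy and deterministic; to obtain more varied responses, you can experiment with a larger `top_k` and a nonzero `temperature`, for example as sketched in the commented-out cell below (the specific values are just an arbitrary starting point)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sampling-example",
"metadata": {},
"outputs": [],
"source": [
"# Example sampling settings (uncomment to try; values are arbitrary starting points)\n",
"# torch.manual_seed(123)\n",
"# token_ids = generate(\n",
"#     model=model,\n",
"#     idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n",
"#     max_new_tokens=150,\n",
"#     context_size=LLAMA32_CONFIG[\"context_length\"],\n",
"#     top_k=25,\n",
"#     temperature=0.7\n",
"# )\n",
"# print(\"Output text:\\n\", clean_text(token_ids_to_text(token_ids, tokenizer)))"
]
},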
{
"cell_type": "markdown",
"id": "549324d6-5c71-4147-ae21-2e67675faa3d",
"metadata": {
"id": "549324d6-5c71-4147-ae21-2e67675faa3d"
},
"source": [
"&nbsp;\n",
"# What's next?"
]
},
{
"cell_type": "markdown",
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
"metadata": {
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
},
"source": [
"- The notebook was kept purposefully minimal; if you are interested in additional explanation about the individual components, check out the following two companion notebooks:\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/gpt-and-all-llamas.webp\">\n",
"\n",
" 1. [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n",
" 2. [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n",
" \n",
"- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n",
"\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"0dccd57dcc5c43a588157cef957c07e8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"17a3174e65c54476b2e0d1faf8f011ca": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_90a79523187446dfa692723b2e5833a7",
"placeholder": "",
"style": "IPY_MODEL_431ffb83b8c14bf182f0430e07ea6154",
"tabbable": null,
"tooltip": null,
"value": "model.safetensors:35%"
}
},
"1bbf2e62c0754d1593beb4105a7f1ac1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_a8f1b72a33dd4b548de23fbd95e0da18",
"max": 2471645608,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_25cc36132d384189acfbecc59483134b",
"tabbable": null,
"tooltip": null,
"value": 880803840
}
},
"25cc36132d384189acfbecc59483134b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"271e2bd6a35e4a8b92de8697f7c0be5f": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3326b6141a1a4eba9f316df528a9b99a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"33ca0cdf2c7f41598a381c4ebe6a4ee1": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"431ffb83b8c14bf182f0430e07ea6154": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"5176834aa8784bba9ec21234b87a8948": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"518fb202e4b44aaba47f07d1a61b6762": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_e2dc407afcd945c798e30597fddfcb3c",
"placeholder": "",
"style": "IPY_MODEL_0dccd57dcc5c43a588157cef957c07e8",
"tabbable": null,
"tooltip": null,
"value": "tokenizer.model:100%"
}
},
"672cdc5aea954de3af851c001a667ad3": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_33ca0cdf2c7f41598a381c4ebe6a4ee1",
"max": 2183982,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_ee44487f58454dacb522b1e084ffb733",
"tabbable": null,
"tooltip": null,
"value": 2183982
}
},
"90a79523187446dfa692723b2e5833a7": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9881b6995c3f49dc89e6992fd9ab660b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_17a3174e65c54476b2e0d1faf8f011ca",
"IPY_MODEL_1bbf2e62c0754d1593beb4105a7f1ac1",
"IPY_MODEL_b82112e1dec645d98aa1c1ba64abcb61"
],
"layout": "IPY_MODEL_271e2bd6a35e4a8b92de8697f7c0be5f",
"tabbable": null,
"tooltip": null
}
},
"a1608feac06d4687967a3e398f01c489": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_518fb202e4b44aaba47f07d1a61b6762",
"IPY_MODEL_672cdc5aea954de3af851c001a667ad3",
"IPY_MODEL_eebf8874618746b39cf4a21a2728dc7f"
],
"layout": "IPY_MODEL_5176834aa8784bba9ec21234b87a8948",
"tabbable": null,
"tooltip": null
}
},
"a8f1b72a33dd4b548de23fbd95e0da18": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b82112e1dec645d98aa1c1ba64abcb61": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_bfd06423ad544218968648016e731a46",
"placeholder": "",
"style": "IPY_MODEL_d029630b63ff44cf807ade428d2eb421",
"tabbable": null,
"tooltip": null,
"value": "870M/2.47G[00:20&lt;00:37,42.8MB/s]"
}
},
"bfd06423ad544218968648016e731a46": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d029630b63ff44cf807ade428d2eb421": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"d2c41e71a3f441deaed091b620ac5603": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e2dc407afcd945c798e30597fddfcb3c": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ee44487f58454dacb522b1e084ffb733": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"eebf8874618746b39cf4a21a2728dc7f": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_d2c41e71a3f441deaed091b620ac5603",
"placeholder": "",
"style": "IPY_MODEL_3326b6141a1a4eba9f316df528a9b99a",
"tabbable": null,
"tooltip": null,
"value": "2.18M/2.18M[00:00&lt;00:00,9.47MB/s]"
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}