mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 02:50:15 +00:00
1789 lines
60 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
   "metadata": {
    "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
   },
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
   "metadata": {
    "id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
   },
   "source": [
    "# Qwen3 0.6B From Scratch (A Standalone Notebook)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
   "metadata": {
    "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
   },
   "source": [
    "- This notebook is purposefully minimal and focuses on the code to implement Qwen3 0.6B; for more information about this model, please see the original blog post and technical report:\n",
    "  - [Qwen3: Think Deeper, Act Faster](https://qwenlm.github.io/blog/qwen3/)\n",
    "  - [Qwen3 Technical Report](https://arxiv.org/abs/2505.09388)\n",
    "- Many architectural components in Qwen3 are similar to Llama 3; for a step-by-step guide that explains the individual components and the relationship between GPT and the components used here, you may like the GPT-to-Llama conversion notebooks:\n",
    "  - [Converting a From-Scratch GPT Architecture to Llama 2](../07_gpt_to_llama/converting-gpt-to-llama2.ipynb)\n",
    "  - [Converting Llama 2 to Llama 3.2 From Scratch](../07_gpt_to_llama/converting-llama2-to-llama3.ipynb)\n",
    " \n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen-overview.webp\">\n",
    " \n",
    " \n",
    "- About the code:\n",
    "  - all code is my own code, mapping the Qwen3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7c201adb-747e-437b-9a62-442802941e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
    "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface_hub version: 0.33.0\n",
      "tokenizers version: 0.21.1\n",
      "torch version: 2.6.0\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\n",
    "    \"huggingface_hub\",  # to download pretrained weights\n",
    "    \"tokenizers\",       # to implement the tokenizer\n",
    "    \"torch\",            # to implement the model\n",
    "]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07e96fbb-8e16-4f6d-835f-c6159321280b",
   "metadata": {},
   "source": [
    "- This notebook supports both the base model and the reasoning (\"thinking\") model; which model to use can be controlled via the following flag:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "70a90338-624a-4706-aa55-6b4358070194",
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_REASONING_MODEL = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
   "metadata": {
    "id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
   },
   "source": [
    " \n",
    "# 1. Architecture code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "82076c21-9331-4dcd-b017-42b046cf1a60",
   "metadata": {
    "id": "82076c21-9331-4dcd-b017-42b046cf1a60"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class FeedForward(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "        self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "        self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x_fc1 = self.fc1(x)\n",
    "        x_fc2 = self.fc2(x)\n",
    "        x = nn.functional.silu(x_fc1) * x_fc2\n",
    "        return self.fc3(x)"
   ]
  },
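  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e01",
   "metadata": {},
   "source": [
    "- The `FeedForward` module above is a SwiGLU-style gated MLP. `fc1` produces a gate that is passed through SiLU, `fc2` produces the values being gated, and `fc3` projects their elementwise product from `hidden_dim` back to `emb_dim`\n",
    "- The minimal pure-Python sketch below traces the same data flow on tiny, hand-picked matrices. The helpers `silu`, `matvec`, and `swiglu_ffn` are hypothetical illustrations and are not part of this notebook or of PyTorch. Each weight matrix uses the `(out_features, in_features)` layout of `nn.Linear.weight`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def silu(x):\n",
    "    # SiLU (also called swish): x * sigmoid(x)\n",
    "    return x / (1.0 + math.exp(-x))\n",
    "\n",
    "\n",
    "def matvec(weight, x):\n",
    "    # weight has shape (out_features, in_features), like nn.Linear.weight\n",
    "    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]\n",
    "\n",
    "\n",
    "def swiglu_ffn(x, w_gate, w_up, w_down):\n",
    "    gate = [silu(v) for v in matvec(w_gate, x)]  # like silu(fc1(x))\n",
    "    up = matvec(w_up, x)                         # like fc2(x)\n",
    "    hidden = [g * u for g, u in zip(gate, up)]   # elementwise gating\n",
    "    return matvec(w_down, hidden)                # like fc3(...)\n",
    "\n",
    "\n",
    "# Toy sizes: emb_dim=2, hidden_dim=3 (Qwen3 0.6B uses 1024 and 3072)\n",
    "w_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]\n",
    "w_up = [[1.0, 1.0], [1.0, -1.0], [0.5, 0.5]]\n",
    "w_down = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]\n",
    "\n",
    "out = swiglu_ffn([1.0, 2.0], w_gate, w_up, w_down)\n",
    "print(len(out))  # 2, i.e., the output has the same size as the input"
   ]
  },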
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56715760-37e1-433e-89da-04864c139a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RMSNorm(nn.Module):\n",
    "    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):\n",
    "        super().__init__()\n",
    "        self.eps = eps\n",
    "        self.qwen3_compatible = qwen3_compatible\n",
    "        self.scale = nn.Parameter(torch.ones(emb_dim))\n",
    "        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None\n",
    "\n",
    "    def forward(self, x):\n",
    "        input_dtype = x.dtype\n",
    "\n",
    "        if self.qwen3_compatible:\n",
    "            x = x.to(torch.float32)\n",
    "\n",
    "        variance = x.pow(2).mean(dim=-1, keepdim=True)\n",
    "        norm_x = x * torch.rsqrt(variance + self.eps)\n",
    "        norm_x = norm_x * self.scale\n",
    "\n",
    "        if self.shift is not None:\n",
    "            norm_x = norm_x + self.shift\n",
    "\n",
    "        return norm_x.to(input_dtype)"
   ]
  },
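  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e03",
   "metadata": {},
   "source": [
    "- `RMSNorm` divides each vector by its root mean square, `sqrt(mean(x**2) + eps)`, and then multiplies the result by the learned `scale`. Unlike LayerNorm, it does not subtract the mean. The optional `shift` is unused here because `bias=False`\n",
    "- As a quick illustration, the hypothetical pure-Python helper `rms_norm` below applies the same formula to a single vector. With a unit scale, the output should have a root mean square of approximately 1 (slightly below 1 because of `eps`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def rms_norm(x, scale=None, eps=1e-6):\n",
    "    # Root mean square over the vector; unlike LayerNorm, the mean is not subtracted\n",
    "    mean_sq = sum(v * v for v in x) / len(x)\n",
    "    inv_rms = 1.0 / math.sqrt(mean_sq + eps)\n",
    "    if scale is None:\n",
    "        scale = [1.0] * len(x)  # corresponds to the initial torch.ones(emb_dim) scale\n",
    "    return [v * inv_rms * s for v, s in zip(x, scale)]\n",
    "\n",
    "\n",
    "x = [3.0, -4.0, 0.0, 5.0]\n",
    "y = rms_norm(x)\n",
    "rms_y = math.sqrt(sum(v * v for v in y) / len(y))\n",
    "print([round(v, 4) for v in y])\n",
    "print(round(rms_y, 4))"
   ]
  },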
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4b9a346f-5826-4083-9162-abd56afc03f0",
   "metadata": {
    "id": "4b9a346f-5826-4083-9162-abd56afc03f0"
   },
   "outputs": [],
   "source": [
    "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):\n",
    "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
    "\n",
    "    # Compute the inverse frequencies\n",
    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n",
    "\n",
    "    # Generate position indices\n",
    "    positions = torch.arange(context_length, dtype=dtype)\n",
    "\n",
    "    # Compute the angles\n",
    "    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)\n",
    "\n",
    "    # Expand angles to match the head_dim\n",
    "    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)\n",
    "\n",
    "    # Precompute sine and cosine\n",
    "    cos = torch.cos(angles)\n",
    "    sin = torch.sin(angles)\n",
    "\n",
    "    return cos, sin\n",
    "\n",
    "\n",
    "def apply_rope(x, cos, sin):\n",
    "    # x: (batch_size, num_heads, seq_len, head_dim)\n",
    "    batch_size, num_heads, seq_len, head_dim = x.shape\n",
    "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
    "\n",
    "    # Split x into first half and second half\n",
    "    x1 = x[..., : head_dim // 2]  # First half\n",
    "    x2 = x[..., head_dim // 2 :]  # Second half\n",
    "\n",
    "    # Adjust sin and cos shapes\n",
    "    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)\n",
    "    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
    "\n",
    "    # Apply the rotary transformation\n",
    "    rotated = torch.cat((-x2, x1), dim=-1)\n",
    "    x_rotated = (x * cos) + (rotated * sin)\n",
    "\n",
    "    # It's ok to use lower-precision after applying cos and sin rotation\n",
    "    return x_rotated.to(dtype=x.dtype)"
   ]
  },
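  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e05",
   "metadata": {},
   "source": [
    "- `compute_rope_params` and `apply_rope` implement rotary position embeddings (RoPE) in the rotate-half layout. Dimension `i` of the first half is paired with dimension `i + head_dim // 2`, and each pair is rotated by the angle `position * inv_freq[i]`, where `inv_freq[i] = theta_base ** (-2 * i / head_dim)`\n",
    "- The sketch below re-implements this rule for a single vector in plain Python. The helpers `rope_single` and `dot` are hypothetical and independent of the PyTorch code above. It illustrates two properties: the rotation preserves the vector norm, and the query-key dot product depends only on the offset between the two positions, not on their absolute values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def rope_single(x, position, theta_base=10_000.0):\n",
    "    # x: list of length head_dim (even); rotate-half layout as in apply_rope\n",
    "    head_dim = len(x)\n",
    "    half = head_dim // 2\n",
    "    out = [0.0] * head_dim\n",
    "    for i in range(half):\n",
    "        inv_freq = 1.0 / (theta_base ** (2 * i / head_dim))\n",
    "        angle = position * inv_freq\n",
    "        c, s = math.cos(angle), math.sin(angle)\n",
    "        x1, x2 = x[i], x[i + half]\n",
    "        out[i] = x1 * c - x2 * s         # first half:  x1 * cos + (-x2) * sin\n",
    "        out[i + half] = x2 * c + x1 * s  # second half: x2 * cos + x1 * sin\n",
    "    return out\n",
    "\n",
    "\n",
    "def dot(a, b):\n",
    "    return sum(u * v for u, v in zip(a, b))\n",
    "\n",
    "\n",
    "q = [0.3, -1.2, 0.8, 0.5]\n",
    "k = [1.0, 0.4, -0.7, 0.2]\n",
    "\n",
    "# The rotation preserves the vector norm\n",
    "norm_q = math.sqrt(dot(q, q))\n",
    "norm_rq = math.sqrt(dot(rope_single(q, 7), rope_single(q, 7)))\n",
    "\n",
    "# Same relative offset (query 3 positions after key) at different absolute positions\n",
    "score_a = dot(rope_single(q, 5), rope_single(k, 2))\n",
    "score_b = dot(rope_single(q, 13), rope_single(k, 10))\n",
    "print(abs(norm_q - norm_rq) < 1e-9, abs(score_a - score_b) < 1e-9)  # True True"
   ]
  },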
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
   "metadata": {
    "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
   },
   "outputs": [],
   "source": [
    "class GroupedQueryAttention(nn.Module):\n",
    "    def __init__(\n",
    "        self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None\n",
    "    ):\n",
    "        super().__init__()\n",
    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
    "\n",
    "        self.num_heads = num_heads\n",
    "        self.num_kv_groups = num_kv_groups\n",
    "        self.group_size = num_heads // num_kv_groups\n",
    "\n",
    "        if head_dim is None:\n",
    "            assert d_in % num_heads == 0, \"`d_in` must be divisible by `num_heads` if `head_dim` is not set\"\n",
    "            head_dim = d_in // num_heads\n",
    "\n",
    "        self.head_dim = head_dim\n",
    "        self.d_out = num_heads * head_dim\n",
    "\n",
    "        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)\n",
    "        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
    "        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
    "\n",
    "        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n",
    "\n",
    "        if qk_norm:\n",
    "            self.q_norm = RMSNorm(head_dim, eps=1e-6)\n",
    "            self.k_norm = RMSNorm(head_dim, eps=1e-6)\n",
    "        else:\n",
    "            self.q_norm = self.k_norm = None\n",
    "\n",
    "    def forward(self, x, mask, cos, sin):\n",
    "        b, num_tokens, _ = x.shape\n",
    "\n",
    "        # Apply projections\n",
    "        queries = self.W_query(x)  # (b, num_tokens, num_heads * head_dim)\n",
    "        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)\n",
    "        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)\n",
    "\n",
    "        # Reshape\n",
    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
    "        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
    "        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
    "\n",
    "        # Optional normalization\n",
    "        if self.q_norm is not None:\n",
    "            queries = self.q_norm(queries)\n",
    "        if self.k_norm is not None:\n",
    "            keys = self.k_norm(keys)\n",
    "\n",
    "        # Apply RoPE\n",
    "        queries = apply_rope(queries, cos, sin)\n",
    "        keys = apply_rope(keys, cos, sin)\n",
    "\n",
    "        # Expand K and V to match number of heads\n",
    "        keys = keys.repeat_interleave(self.group_size, dim=1)\n",
    "        values = values.repeat_interleave(self.group_size, dim=1)\n",
    "\n",
    "        # Attention\n",
    "        attn_scores = queries @ keys.transpose(2, 3)\n",
    "        attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
    "        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
    "\n",
    "        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n",
    "        return self.out_proj(context)"
   ]
  },
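  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e07",
   "metadata": {},
   "source": [
    "- In `GroupedQueryAttention`, `repeat_interleave(self.group_size, dim=1)` expands the `num_kv_groups` key/value heads so that each run of `group_size` consecutive query heads shares one key/value head. For Qwen3 0.6B (16 query heads, 8 key/value groups), pairs of query heads therefore share keys and values\n",
    "- The following pure-Python sketch lists the key/value head used by each query head. The helpers `kv_head_for_query_head` and `repeat_interleave` are hypothetical stand-ins, not PyTorch functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e08",
   "metadata": {},
   "outputs": [],
   "source": [
    "def kv_head_for_query_head(num_heads, num_kv_groups):\n",
    "    assert num_heads % num_kv_groups == 0, 'num_heads must be divisible by num_kv_groups'\n",
    "    group_size = num_heads // num_kv_groups\n",
    "    # repeat_interleave repeats each KV head group_size times in a row,\n",
    "    # so query head h ends up paired with KV head h // group_size\n",
    "    return [h // group_size for h in range(num_heads)]\n",
    "\n",
    "\n",
    "def repeat_interleave(items, repeats):\n",
    "    # Pure-Python analogue of torch.repeat_interleave along a single dimension\n",
    "    return [item for item in items for _ in range(repeats)]\n",
    "\n",
    "\n",
    "mapping = kv_head_for_query_head(num_heads=16, num_kv_groups=8)\n",
    "print(mapping)  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]\n",
    "print(repeat_interleave(list(range(8)), 2) == mapping)  # True"
   ]
  },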
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
   "metadata": {
    "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
   },
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.att = GroupedQueryAttention(\n",
    "            d_in=cfg[\"emb_dim\"],\n",
    "            num_heads=cfg[\"n_heads\"],\n",
    "            head_dim=cfg[\"head_dim\"],\n",
    "            num_kv_groups=cfg[\"n_kv_groups\"],\n",
    "            qk_norm=cfg[\"qk_norm\"],\n",
    "            dtype=cfg[\"dtype\"]\n",
    "        )\n",
    "        self.ff = FeedForward(cfg)\n",
    "        self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n",
    "        self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n",
    "\n",
    "    def forward(self, x, mask, cos, sin):\n",
    "        # Shortcut connection for attention block\n",
    "        shortcut = x\n",
    "        x = self.norm1(x)\n",
    "        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]\n",
    "        x = x + shortcut  # Add the original input back\n",
    "\n",
    "        # Shortcut connection for feed-forward block\n",
    "        shortcut = x\n",
    "        x = self.norm2(x)\n",
    "        x = self.ff(x)\n",
    "        x = x + shortcut  # Add the original input back\n",
    "\n",
    "        return x"
   ]
  },
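  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e09",
   "metadata": {},
   "source": [
    "- `TransformerBlock` uses a pre-norm residual layout. Each sub-layer (attention, then feed-forward) receives a normalized copy of its input, and the sub-layer's output is added back to the un-normalized input\n",
    "- A minimal pure-Python sketch of this wiring follows. The plain functions `identity`, `double`, and `zeros` are hypothetical stand-ins for `norm1`/`norm2`, `att`, and `ff`, not the real modules. It shows that the shortcut path carries the input through unchanged, and each sub-layer only adds an update on top of it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pre_norm_block(x, norm1, attn, norm2, ff):\n",
    "    # Attention sub-block: normalize, transform, then add the shortcut\n",
    "    h = [a + s for a, s in zip(attn(norm1(x)), x)]\n",
    "    # Feed-forward sub-block: same pattern on the updated hidden state\n",
    "    return [a + s for a, s in zip(ff(norm2(h)), h)]\n",
    "\n",
    "\n",
    "def identity(v):\n",
    "    return list(v)\n",
    "\n",
    "\n",
    "def double(v):\n",
    "    return [2.0 * e for e in v]\n",
    "\n",
    "\n",
    "def zeros(v):\n",
    "    return [0.0] * len(v)\n",
    "\n",
    "\n",
    "x = [1.0, -2.0, 0.5]\n",
    "# If both sub-layers output zeros, the block reduces to the identity\n",
    "print(pre_norm_block(x, identity, zeros, identity, zeros))   # [1.0, -2.0, 0.5]\n",
    "# The attention stand-in adds 2 * x on top of x\n",
    "print(pre_norm_block(x, identity, double, identity, zeros))  # [3.0, -6.0, 1.5]"
   ]
  },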
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
   "metadata": {
    "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
   },
   "outputs": [],
   "source": [
    "class Qwen3Model(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "\n",
    "        # Main model parameters\n",
    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
    "\n",
    "        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n",
    "            [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
    "        )\n",
    "\n",
    "        self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
    "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
    "\n",
    "        # Reusable utilities\n",
    "        if cfg[\"head_dim\"] is None:\n",
    "            head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
    "        else:\n",
    "            head_dim = cfg[\"head_dim\"]\n",
    "        cos, sin = compute_rope_params(\n",
    "            head_dim=head_dim,\n",
    "            theta_base=cfg[\"rope_base\"],\n",
    "            context_length=cfg[\"context_length\"]\n",
    "        )\n",
    "        self.register_buffer(\"cos\", cos, persistent=False)\n",
    "        self.register_buffer(\"sin\", sin, persistent=False)\n",
    "        self.cfg = cfg\n",
    "\n",
    "    def forward(self, in_idx):\n",
    "        # Forward pass\n",
    "        tok_embeds = self.tok_emb(in_idx)\n",
    "        x = tok_embeds\n",
    "\n",
    "        num_tokens = x.shape[1]\n",
    "        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
    "\n",
    "        for block in self.trf_blocks:\n",
    "            x = block(x, mask, self.cos, self.sin)\n",
    "        x = self.final_norm(x)\n",
    "        logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
    "        return logits"
   ]
  },
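  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e11",
   "metadata": {},
   "source": [
    "- `Qwen3Model.forward` builds the causal mask with `torch.triu(..., diagonal=1)`. Entries above the main diagonal are `True`, and `masked_fill` sets the corresponding attention scores to `-inf`. After the softmax, token `i` therefore only attends to tokens `0` through `i`\n",
    "- The following is a small pure-Python sketch of the same mask and its effect on one row of scores. The helpers `causal_mask` and `masked_softmax` are hypothetical, and the scores are made-up example values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def causal_mask(num_tokens):\n",
    "    # True where key position j lies after query position i (j > i); this matches\n",
    "    # torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)\n",
    "    return [[j > i for j in range(num_tokens)] for i in range(num_tokens)]\n",
    "\n",
    "\n",
    "def masked_softmax(scores, mask_row):\n",
    "    masked = [-math.inf if m else s for s, m in zip(scores, mask_row)]\n",
    "    max_s = max(masked)  # subtract the max for numerical stability\n",
    "    exps = [math.exp(s - max_s) for s in masked]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "mask = causal_mask(3)\n",
    "for row in mask:\n",
    "    print(row)\n",
    "\n",
    "# The first token can only attend to itself\n",
    "print(masked_softmax([0.5, 2.0, -1.0], mask[0]))  # [1.0, 0.0, 0.0]"
   ]
  },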
  {
   "cell_type": "markdown",
   "id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
   "metadata": {
    "id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
   },
   "source": [
    " \n",
    "# 2. Initialize model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "caa142fa-b375-4e78-b392-2072ced666f3",
   "metadata": {
    "id": "caa142fa-b375-4e78-b392-2072ced666f3"
   },
   "outputs": [],
   "source": [
    "# Qwen3 0.6B\n",
    "\n",
    "QWEN3_CONFIG = {\n",
    "    \"vocab_size\": 151_936,      # Vocabulary size\n",
    "    \"context_length\": 40_960,   # Context length that was used to train the model\n",
    "    \"emb_dim\": 1024,            # Embedding dimension\n",
    "    \"n_heads\": 16,              # Number of attention heads\n",
    "    \"n_layers\": 28,             # Number of layers\n",
    "    \"hidden_dim\": 3072,         # Size of the intermediate dimension in FeedForward\n",
    "    \"head_dim\": 128,            # Size of the heads in GQA\n",
    "    \"qk_norm\": True,            # Whether to normalize queries and keys in GQA\n",
    "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
    "    \"rope_base\": 1_000_000.0,   # The base in RoPE's \"theta\"\n",
    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
    "}"
   ]
  },
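  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e13",
   "metadata": {},
   "source": [
    "- In this configuration, `n_heads * head_dim` (16 × 128 = 2048) differs from `emb_dim` (1024). This is why `head_dim` is set explicitly and why `W_query` and `out_proj` are not square matrices\n",
    "- The sketch below derives the `(in_features, out_features)` shape of every linear layer in one transformer block from a trimmed copy of the configuration. The helper `projection_shapes` is hypothetical. The results should match the model printout further below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e14",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {  # trimmed copy of the relevant QWEN3_CONFIG entries\n",
    "    'emb_dim': 1024,\n",
    "    'n_heads': 16,\n",
    "    'n_kv_groups': 8,\n",
    "    'head_dim': 128,\n",
    "    'hidden_dim': 3072,\n",
    "}\n",
    "\n",
    "\n",
    "def projection_shapes(cfg):\n",
    "    # (in_features, out_features) of each bias-free nn.Linear in one transformer block\n",
    "    q_out = cfg['n_heads'] * cfg['head_dim']       # 2048\n",
    "    kv_out = cfg['n_kv_groups'] * cfg['head_dim']  # 1024\n",
    "    return {\n",
    "        'W_query': (cfg['emb_dim'], q_out),\n",
    "        'W_key': (cfg['emb_dim'], kv_out),\n",
    "        'W_value': (cfg['emb_dim'], kv_out),\n",
    "        'out_proj': (q_out, cfg['emb_dim']),\n",
    "        'fc1': (cfg['emb_dim'], cfg['hidden_dim']),\n",
    "        'fc2': (cfg['emb_dim'], cfg['hidden_dim']),\n",
    "        'fc3': (cfg['hidden_dim'], cfg['emb_dim']),\n",
    "    }\n",
    "\n",
    "\n",
    "for name, shape in projection_shapes(config).items():\n",
    "    print(name, shape)"
   ]
  },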
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
   "metadata": {
    "id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(123)\n",
    "model = Qwen3Model(QWEN3_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "eaf86265-4e9d-4024-9ed0-99076944e304",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Qwen3Model(\n",
       "  (tok_emb): Embedding(151936, 1024)\n",
       "  (trf_blocks): ModuleList(\n",
       "    (0-27): 28 x TransformerBlock(\n",
       "      (att): GroupedQueryAttention(\n",
       "        (W_query): Linear(in_features=1024, out_features=2048, bias=False)\n",
       "        (W_key): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "        (W_value): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n",
       "        (q_norm): RMSNorm()\n",
       "        (k_norm): RMSNorm()\n",
       "      )\n",
       "      (ff): FeedForward(\n",
       "        (fc1): Linear(in_features=1024, out_features=3072, bias=False)\n",
       "        (fc2): Linear(in_features=1024, out_features=3072, bias=False)\n",
       "        (fc3): Linear(in_features=3072, out_features=1024, bias=False)\n",
       "      )\n",
       "      (norm1): RMSNorm()\n",
       "      (norm2): RMSNorm()\n",
       "    )\n",
       "  )\n",
       "  (final_norm): RMSNorm()\n",
       "  (out_head): Linear(in_features=1024, out_features=151936, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b",
   "metadata": {},
   "source": [
    "- A quick check that the forward pass works before continuing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.2256, -0.0164, -0.7070,  ...,  0.4414,  0.1245,  1.0703],\n",
       "         [-0.6602,  0.5352, -0.0718,  ..., -0.0737,  0.5391,  0.3086],\n",
       "         [-0.4785, -0.1562,  0.1045,  ..., -0.2324,  0.2354,  0.6328]]],\n",
       "       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(torch.tensor([1, 2, 3]).unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
    "outputId": "00d7e983-262e-4c65-f322-f4d999311988"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of parameters: 751,632,384\n",
      "\n",
      "Total number of unique parameters: 596,049,920\n"
     ]
    }
   ],
   "source": [
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Total number of parameters: {total_params:,}\")\n",
    "\n",
    "# Account for weight tying\n",
    "total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
    "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
   ]
  },
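  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e15",
   "metadata": {},
   "source": [
    "- As a cross-check of the two numbers printed above, the following sketch recomputes them from a trimmed copy of `QWEN3_CONFIG`. It assumes bias-free linear layers and `RMSNorm` layers that only hold a `scale` vector, as in the classes defined earlier. The helper `count_params` is hypothetical\n",
    "- The second number subtracts one copy of the 151,936 × 1024 embedding matrix because the pretrained checkpoint ties the output head to the token embedding (see the weight-loading code below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e16",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {  # trimmed copy of the QWEN3_CONFIG entries that determine the parameter count\n",
    "    'vocab_size': 151_936,\n",
    "    'emb_dim': 1024,\n",
    "    'n_heads': 16,\n",
    "    'n_layers': 28,\n",
    "    'hidden_dim': 3072,\n",
    "    'head_dim': 128,\n",
    "    'qk_norm': True,\n",
    "    'n_kv_groups': 8,\n",
    "}\n",
    "\n",
    "\n",
    "def count_params(cfg):\n",
    "    d, h = cfg['emb_dim'], cfg['head_dim']\n",
    "    q_out = cfg['n_heads'] * h\n",
    "    kv_out = cfg['n_kv_groups'] * h\n",
    "    attn = d * q_out + 2 * d * kv_out + q_out * d  # W_query, W_key, W_value, out_proj (no biases)\n",
    "    qk_norm = 2 * h if cfg['qk_norm'] else 0       # q_norm and k_norm scales\n",
    "    ff = 3 * d * cfg['hidden_dim']                 # fc1, fc2, fc3 (no biases)\n",
    "    block_norms = 2 * d                            # norm1 and norm2 scales\n",
    "    per_layer = attn + qk_norm + ff + block_norms\n",
    "    embedding = cfg['vocab_size'] * d              # tok_emb (out_head has the same shape)\n",
    "    total = embedding + cfg['n_layers'] * per_layer + d + embedding  # ... + final_norm + out_head\n",
    "    return total, total - embedding\n",
    "\n",
    "\n",
    "total, unique = count_params(config)\n",
    "print(f'{total:,} / {unique:,}')  # 751,632,384 / 596,049,920"
   ]
  },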
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
    "outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "float32 (PyTorch default): 5.64 GB\n",
      "bfloat16: 2.82 GB\n"
     ]
    }
   ],
   "source": [
    "def model_memory_size(model, input_dtype=torch.float32):\n",
    "    total_params = 0\n",
    "    total_grads = 0\n",
    "    for param in model.parameters():\n",
    "        # Calculate total number of elements per parameter\n",
    "        param_size = param.numel()\n",
    "        total_params += param_size\n",
    "        # Check if gradients are stored for this parameter\n",
    "        if param.requires_grad:\n",
    "            total_grads += param_size\n",
    "\n",
    "    # Calculate buffer size (non-parameters that require memory)\n",
    "    total_buffers = sum(buf.numel() for buf in model.buffers())\n",
    "\n",
    "    # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
    "    # We assume parameters and gradients are stored in the same type as input dtype\n",
    "    element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
    "    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
    "\n",
    "    # Convert bytes to gigabytes\n",
    "    total_memory_gb = total_memory_bytes / (1024**3)\n",
    "\n",
    "    return total_memory_gb\n",
    "\n",
    "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
    "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
   ]
  },
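  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e17",
   "metadata": {},
   "source": [
    "- `model_memory_size` counts parameters, gradients (for every parameter with `requires_grad=True`), and buffers. It then multiplies the total by the byte size of a single dtype\n",
    "- The sketch below reproduces the printed numbers by hand. It assumes that all 751,632,384 parameters require gradients and that the only buffers are the `cos` and `sin` RoPE tables, each of shape `(context_length, head_dim)`. The helper `memory_gb` is hypothetical. Because gradients are not needed for inference, simply running the model takes roughly half of these estimates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "def memory_gb(num_params, num_grads, num_buffer_elems, bytes_per_elem):\n",
    "    # Same simplification as model_memory_size: every tensor uses one dtype\n",
    "    return (num_params + num_grads + num_buffer_elems) * bytes_per_elem / 1024**3\n",
    "\n",
    "\n",
    "num_params = 751_632_384        # total printed by the parameter-count cell above\n",
    "num_buffers = 2 * 40_960 * 128  # cos and sin, each (context_length, head_dim)\n",
    "\n",
    "fp32 = memory_gb(num_params, num_params, num_buffers, bytes_per_elem=4)\n",
    "bf16 = memory_gb(num_params, num_params, num_buffers, bytes_per_elem=2)\n",
    "print(f'{fp32:.2f} GB, {bf16:.2f} GB')  # 5.64 GB, 2.82 GB"
   ]
  },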
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
   "metadata": {
    "id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "model.to(device);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c172f89f-d301-439f-b809-46169e5f5945",
   "metadata": {
    "id": "c172f89f-d301-439f-b809-46169e5f5945"
   },
   "source": [
    " \n",
    "# 3. Load pretrained weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "75166128-5899-4995-9b88-9672e135650e",
   "metadata": {
    "id": "75166128-5899-4995-9b88-9672e135650e"
   },
   "outputs": [],
   "source": [
    "def load_weights_into_qwen(model, param_config, params):\n",
    "    def assign(left, right, tensor_name=\"unknown\"):\n",
    "        if left.shape != right.shape:\n",
    "            raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
    "        return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
    "\n",
    "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
    "\n",
    "    for l in range(param_config[\"n_layers\"]):\n",
    "        block = model.trf_blocks[l]\n",
    "        att = block.att\n",
    "\n",
    "        # Q, K, V projections\n",
    "        att.W_query.weight = assign(\n",
    "            att.W_query.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
    "        )\n",
    "        att.W_key.weight = assign(\n",
    "            att.W_key.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
    "        )\n",
    "        att.W_value.weight = assign(\n",
    "            att.W_value.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
    "        )\n",
    "\n",
    "        # Output projection\n",
    "        att.out_proj.weight = assign(\n",
    "            att.out_proj.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
    "        )\n",
    "\n",
    "        # QK norms\n",
    "        if hasattr(att, \"q_norm\") and att.q_norm is not None:\n",
    "            att.q_norm.scale = assign(\n",
    "                att.q_norm.scale,\n",
    "                params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n",
    "                f\"model.layers.{l}.self_attn.q_norm.weight\"\n",
    "            )\n",
    "        if hasattr(att, \"k_norm\") and att.k_norm is not None:\n",
    "            att.k_norm.scale = assign(\n",
    "                att.k_norm.scale,\n",
    "                params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n",
    "                f\"model.layers.{l}.self_attn.k_norm.weight\"\n",
    "            )\n",
    "\n",
    "        # Attention layernorm\n",
    "        block.norm1.scale = assign(\n",
    "            block.norm1.scale,\n",
    "            params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
    "            f\"model.layers.{l}.input_layernorm.weight\"\n",
    "        )\n",
    "\n",
    "        # Feedforward weights\n",
    "        block.ff.fc1.weight = assign(\n",
    "            block.ff.fc1.weight,\n",
    "            params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
    "        )\n",
    "        block.ff.fc2.weight = assign(\n",
    "            block.ff.fc2.weight,\n",
    "            params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.up_proj.weight\"\n",
    "        )\n",
    "        block.ff.fc3.weight = assign(\n",
    "            block.ff.fc3.weight,\n",
    "            params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.down_proj.weight\"\n",
    "        )\n",
    "        block.norm2.scale = assign(\n",
    "            block.norm2.scale,\n",
    "            params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
    "            f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
    "        )\n",
    "\n",
    "    # Final normalization and output head\n",
    "    model.final_norm.scale = assign(model.final_norm.scale, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
    "\n",
    "    # Model uses weight tying, hence we reuse the embedding layer weights here\n",
    "    model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
   ]
  },
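  {
   "cell_type": "markdown",
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e19",
   "metadata": {},
   "source": [
    "- `load_weights_into_qwen` maps the Hugging Face checkpoint tensor names onto the attributes of the from-scratch model. `out_head.weight` has no entry of its own because the loader fills it from `model.embed_tokens.weight` (weight tying)\n",
    "- For debugging, it can help to have this mapping as plain data. The hypothetical helpers below rebuild it as a dictionary, using the names from the loader above. The dictionary keys could then be compared against `weights_dict.keys()` to spot missing or unused tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c0a2-4f3d-4e8a-9c21-0a1b2c3d4e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_name_mapping(layer):\n",
    "    # Checkpoint name -> from-scratch attribute path for one transformer block,\n",
    "    # following the assignments in load_weights_into_qwen\n",
    "    pairs = {\n",
    "        'self_attn.q_proj.weight': 'att.W_query.weight',\n",
    "        'self_attn.k_proj.weight': 'att.W_key.weight',\n",
    "        'self_attn.v_proj.weight': 'att.W_value.weight',\n",
    "        'self_attn.o_proj.weight': 'att.out_proj.weight',\n",
    "        'self_attn.q_norm.weight': 'att.q_norm.scale',\n",
    "        'self_attn.k_norm.weight': 'att.k_norm.scale',\n",
    "        'input_layernorm.weight': 'norm1.scale',\n",
    "        'mlp.gate_proj.weight': 'ff.fc1.weight',\n",
    "        'mlp.up_proj.weight': 'ff.fc2.weight',\n",
    "        'mlp.down_proj.weight': 'ff.fc3.weight',\n",
    "        'post_attention_layernorm.weight': 'norm2.scale',\n",
    "    }\n",
    "    return {f'model.layers.{layer}.{k}': f'trf_blocks.{layer}.{v}' for k, v in pairs.items()}\n",
    "\n",
    "\n",
    "def full_name_mapping(n_layers):\n",
    "    mapping = {\n",
    "        'model.embed_tokens.weight': 'tok_emb.weight',  # also reused for out_head.weight\n",
    "        'model.norm.weight': 'final_norm.scale',\n",
    "    }\n",
    "    for layer in range(n_layers):\n",
    "        mapping.update(layer_name_mapping(layer))\n",
    "    return mapping\n",
    "\n",
    "\n",
    "mapping = full_name_mapping(28)\n",
    "print(len(mapping))  # 2 + 28 * 11 = 310\n",
    "print(mapping['model.layers.0.mlp.gate_proj.weight'])  # trf_blocks.0.ff.fc1.weight"
   ]
  },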
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 17,
     "referenced_widgets": [
      "9881b6995c3f49dc89e6992fd9ab660b",
      "17a3174e65c54476b2e0d1faf8f011ca",
      "1bbf2e62c0754d1593beb4105a7f1ac1",
      "b82112e1dec645d98aa1c1ba64abcb61",
      "271e2bd6a35e4a8b92de8697f7c0be5f",
      "90a79523187446dfa692723b2e5833a7",
      "431ffb83b8c14bf182f0430e07ea6154",
      "a8f1b72a33dd4b548de23fbd95e0da18",
      "25cc36132d384189acfbecc59483134b",
      "bfd06423ad544218968648016e731a46",
      "d029630b63ff44cf807ade428d2eb421"
     ]
    },
    "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
    "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from safetensors.torch import load_file\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "\n",
    "if USE_REASONING_MODEL:\n",
    "    repo_id = \"Qwen/Qwen3-0.6B\"\n",
    "else:\n",
    "    repo_id = \"Qwen/Qwen3-0.6B-Base\"\n",
    "local_dir = Path(repo_id).parts[-1]\n",
    "\n",
    "\n",
    "weights_file = hf_hub_download(\n",
    "    repo_id=repo_id,\n",
    "    filename=\"model.safetensors\",\n",
    "    local_dir=local_dir\n",
    ")\n",
    "\n",
    "weights_dict = load_file(weights_file)\n",
    "\n",
    "load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)\n",
    "model.to(device)\n",
    "del weights_dict  # free up memory (weights_file is only a path string)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b345491-3510-4397-92d3-cd0a3fa3deee",
   "metadata": {},
   "source": [
    " \n",
    "# 4. Load tokenizer"
   ]
  },
|
|||
|
{
"cell_type": "code",
"execution_count": 19,
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
"metadata": {},
"outputs": [],
"source": [
"from tokenizers import Tokenizer\n",
"\n",
"\n",
"class Qwen3Tokenizer():\n",
"    def __init__(self, tokenizer_file_path=\"tokenizer.json\", repo_id=None, add_generation_prompt=False, add_thinking=False):\n",
"        self.tokenizer_file_path = tokenizer_file_path\n",
"        self.add_generation_prompt = add_generation_prompt\n",
"        self.add_thinking = add_thinking\n",
"\n",
"        tokenizer_file_path_obj = Path(tokenizer_file_path)\n",
"        if not tokenizer_file_path_obj.is_file() and repo_id is not None:\n",
"            _ = hf_hub_download(\n",
"                repo_id=repo_id,\n",
"                filename=str(tokenizer_file_path_obj.name),\n",
"                local_dir=str(tokenizer_file_path_obj.parent.name)\n",
"            )\n",
"        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)\n",
"\n",
"    def encode(self, prompt):\n",
"        messages = [\n",
"            {\"role\": \"user\", \"content\": prompt}\n",
"        ]\n",
"        formatted_prompt = self.format_qwen_chat(\n",
"            messages,\n",
"            add_generation_prompt=self.add_generation_prompt,\n",
"            add_thinking=self.add_thinking\n",
"        )\n",
"        return self.tokenizer.encode(formatted_prompt).ids\n",
"\n",
"    def decode(self, token_ids):\n",
"        return self.tokenizer.decode(token_ids, skip_special_tokens=False)\n",
"\n",
"    @staticmethod\n",
"    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):\n",
"        prompt = \"\"\n",
"        for msg in messages:\n",
"            prompt += f\"<|im_start|>{msg['role']}\\n{msg['content']}<|im_end|>\\n\"\n",
"        if add_generation_prompt:\n",
"            prompt += \"<|im_start|>assistant\\n\"\n",
"            if not add_thinking:\n",
"                # Non-thinking mode: pre-fill an empty think block so the model answers directly\n",
"                prompt += \"<think>\\n\\n</think>\\n\\n\"\n",
"        return prompt"
]
},
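{
"cell_type": "markdown",
"id": "qwen3-chat-format-check-md",
"metadata": {},
"source": [
"- To see exactly what `add_generation_prompt` and `add_thinking` change, we can print the raw prompt strings that `format_qwen_chat` builds; since it is a static method, this needs neither the tokenizer file nor the model\n",
"- The cell below is a small illustrative sketch (not part of the original notebook) that uses a made-up example message\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "qwen3-chat-format-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sanity check (not part of the original notebook): print the raw\n",
"# prompt strings built by the static method above for both modes\n",
"messages = [{\"role\": \"user\", \"content\": \"What is 2 + 2?\"}]\n",
"\n",
"for thinking in (True, False):\n",
"    chat = Qwen3Tokenizer.format_qwen_chat(\n",
"        messages, add_generation_prompt=True, add_thinking=thinking\n",
"    )\n",
"    print(f\"add_thinking={thinking}:\")\n",
"    print(repr(chat))"
]
},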
{
"cell_type": "code",
"execution_count": 20,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [],
"source": [
"if USE_REASONING_MODEL:\n",
"    tokenizer_file_path = \"Qwen3-0.6B/tokenizer.json\"\n",
"else:\n",
"    tokenizer_file_path = \"Qwen3-0.6B-Base/tokenizer.json\"\n",
"\n",
"tokenizer = Qwen3Tokenizer(\n",
"    tokenizer_file_path=tokenizer_file_path,\n",
"    repo_id=repo_id,\n",
"    add_generation_prompt=USE_REASONING_MODEL,\n",
"    add_thinking=USE_REASONING_MODEL\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1946b534-e3af-431a-a222-391a60bfa892",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<|im_start|>user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prompt = \"Give me a short introduction to large language models.\"\n",
"\n",
"input_token_ids = tokenizer.encode(prompt)\n",
"text = tokenizer.decode(input_token_ids)\n",
"text"
]
},
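{
"cell_type": "markdown",
"id": "qwen3-token-inspect-md",
"metadata": {},
"source": [
"- The decoded text above shows the formatted prompt as a single string; the sketch below (not part of the original notebook) instead lists the first few token IDs next to the text each one decodes to, which makes it easy to see that chat markers such as `<|im_start|>` are encoded as single special tokens\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "qwen3-token-inspect-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical inspection cell: map token IDs back to their text pieces\n",
"print(f\"Number of prompt tokens: {len(input_token_ids)}\")\n",
"\n",
"for token_id in input_token_ids[:10]:\n",
"    print(token_id, repr(tokenizer.decode([token_id])))"
]
},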
{
"cell_type": "markdown",
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
"metadata": {
"id": "57d07df1-4401-4792-b549-7c4cc5632323"
},
"source": [
" \n",
"# 5. Generate text"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
},
"outputs": [],
"source": [
"# Identical function from chapter 5\n",
"\n",
"def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n",
"    # For-loop is the same as before: Get logits, and only focus on last time step\n",
"    for _ in range(max_new_tokens):\n",
"        idx_cond = idx[:, -context_size:]\n",
"        with torch.no_grad():\n",
"            logits = model(idx_cond)\n",
"        logits = logits[:, -1, :]\n",
"\n",
"        # Filter logits with top_k sampling\n",
"        if top_k is not None:\n",
"            # Keep only top_k values\n",
"            top_logits, _ = torch.topk(logits, top_k)\n",
"            min_val = top_logits[:, -1]\n",
"            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
"\n",
"        # Apply temperature scaling\n",
"        if temperature > 0.0:\n",
"            logits = logits / temperature\n",
"\n",
"            # Apply softmax to get probabilities\n",
"            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)\n",
"\n",
"            # Sample from the distribution\n",
"            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)\n",
"\n",
"        # Otherwise same as before: get idx of the vocab entry with the highest logits value\n",
"        else:\n",
"            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)\n",
"\n",
"        if eos_id is not None and idx_next.item() == eos_id:\n",
"            break  # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
"\n",
"        # Same as before: append sampled index to the running sequence\n",
"        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)\n",
"\n",
"    return idx"
]
},
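{
"cell_type": "markdown",
"id": "toy-topk-temperature-md",
"metadata": {},
"source": [
"- The notebook calls `generate()` with `top_k=1` and `temperature=0.`, which reduces to greedy decoding via `torch.argmax`\n",
"- To illustrate the sampling branch instead, the self-contained sketch below (not part of the original notebook) applies the same top-k masking and temperature scaling to a made-up logits vector over a 5-token vocabulary\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "toy-topk-temperature-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Made-up logits for a batch of one sequence and a 5-token vocabulary\n",
"toy_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 3.0]])\n",
"\n",
"# Top-k masking, as in generate(): everything below the k-th largest logit becomes -inf\n",
"top_k = 3\n",
"top_logits, _ = torch.topk(toy_logits, top_k)\n",
"min_val = top_logits[:, -1]\n",
"masked = torch.where(toy_logits < min_val, torch.tensor(float(\"-inf\")), toy_logits)\n",
"print(\"Masked logits:\", masked)\n",
"\n",
"# Temperature scaling: lower values sharpen the distribution, higher values flatten it;\n",
"# masked entries receive zero probability after the softmax\n",
"for temperature in (0.5, 1.0, 2.0):\n",
"    probs = torch.softmax(masked / temperature, dim=-1)\n",
"    print(f\"T={temperature}:\", probs)"
]
},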
{
"cell_type": "code",
"execution_count": 23,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 21.96 sec\n",
"<|im_start|>user\n",
"Give me a short introduction to large language models.<|im_end|>\n",
"<|im_start|>assistant\n",
"<think>\n",
"Okay, the user wants a short introduction to large language models. Let me start by recalling what I know. Large language models are AI systems that can understand and generate human language. They're trained on massive datasets, so they can learn complex patterns and nuances.\n",
"\n",
"I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or language assistants. Also, emphasize their adaptability and versatility. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around 3-4 sentences. Make sure it's clear and easy to understand.\n",
"</think>\n",
"\n",
"Large language models (LLMs) are AI systems designed...\n"
]
}
],
"source": [
"import time\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"start = time.time()\n",
"\n",
"output_token_ids = generate(\n",
"    model=model,\n",
"    idx=torch.tensor(input_token_ids, device=device).unsqueeze(0),\n",
"    max_new_tokens=150,\n",
"    context_size=QWEN3_CONFIG[\"context_length\"],\n",
"    top_k=1,\n",
"    temperature=0.\n",
")\n",
"\n",
"print(f\"Time: {time.time() - start:.2f} sec\")\n",
"\n",
"if torch.cuda.is_available():\n",
"    max_mem_bytes = torch.cuda.max_memory_allocated()\n",
"    max_mem_gb = max_mem_bytes / (1024 ** 3)\n",
"    print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n",
"\n",
"output_text = tokenizer.decode(output_token_ids.squeeze(0).tolist())\n",
"\n",
"print(output_text + \"...\")"
]
},
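{
"cell_type": "markdown",
"id": "split-thinking-helper-md",
"metadata": {},
"source": [
"- With the reasoning model, the decoded output contains the model's `<think>...</think>` block followed by the final answer\n",
"- The hypothetical helper below (not part of the original notebook) uses a regular expression to separate the two; it assumes at most one such block and returns the full text as the answer if no closing `</think>` tag was generated, for example because `max_new_tokens` cut the reasoning short\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "split-thinking-helper-code",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"\n",
"def split_thinking(text):\n",
"    # Hypothetical helper: return (reasoning, answer) for text containing <think>...</think>\n",
"    match = re.search(r\"<think>(.*?)</think>\", text, flags=re.DOTALL)\n",
"    if match is None:\n",
"        return \"\", text\n",
"    reasoning = match.group(1).strip()\n",
"    answer = text[match.end():].strip()\n",
"    return reasoning, answer\n",
"\n",
"\n",
"reasoning, answer = split_thinking(output_text)\n",
"print(\"Reasoning length (characters):\", len(reasoning))\n",
"print(\"Answer:\")\n",
"print(answer)"
]
},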
{
"cell_type": "markdown",
"id": "549324d6-5c71-4147-ae21-2e67675faa3d",
"metadata": {
"id": "549324d6-5c71-4147-ae21-2e67675faa3d"
},
"source": [
" \n",
"# What's next?"
]
},
{
"cell_type": "markdown",
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
"metadata": {
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
},
"source": [
"- Check out the [README.md](./README.md) to use this model via the `llms_from_scratch` package\n",
"- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my book [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n",
"\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}