# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from ..qwen3 import Qwen3Tokenizer, download_from_huggingface, load_weights_into_qwen  # noqa: F401

import torch
import torch.nn as nn


# 0.6B model
QWEN_CONFIG_06_B = {
    "vocab_size": 151_936,           # Vocabulary size
    "context_length": 40_960,        # Context length that was used to train the model
    "window_size": None,             # Window size for the KV cache; context_length if None
    "emb_dim": 1024,                 # Embedding dimension
    "n_heads": 16,                   # Number of attention heads
    "n_layers": 28,                  # Number of layers
    "hidden_dim": 3072,              # Size of the intermediate dimension in FeedForward
    "head_dim": 128,                 # Size of the heads in GQA
    "qk_norm": True,                 # Whether to normalize queries and keys in GQA
    "n_kv_groups": 8,                # Key-Value groups for grouped-query attention
    "rope_base": 1_000_000.0,        # The base in RoPE's "theta"
    "dtype": torch.bfloat16,         # Lower-precision dtype to reduce memory usage
}


class Qwen3Model(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, a final norm and an output head.

    Args:
        cfg: Configuration dict with the keys shown in ``QWEN_CONFIG_06_B``.
    """

    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        # ModuleList since Sequential can only accept one input, and every
        # block needs `x, cos, sin, use_cache`
        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusable utilities
        if cfg["head_dim"] is None:
            head_dim = cfg["emb_dim"] // cfg["n_heads"]
        else:
            head_dim = cfg["head_dim"]
        cos, sin = compute_rope_params(
            head_dim=head_dim,
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"]
        )
        # Non-persistent: the tables are recomputed from cfg, not stored in the state dict
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx, use_cache=False):
        """Return logits for the token ids in ``in_idx``.

        Args:
            in_idx: Integer token ids of shape (batch_size, num_tokens).
            use_cache: If True, every attention layer appends its keys/values
                to its KV cache and attends over the cached context.

        Returns:
            Tensor of shape (batch_size, num_tokens, vocab_size).
        """
        x = self.tok_emb(in_idx)

        for block in self.trf_blocks:
            x = block(x, self.cos, self.sin, use_cache)

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits

    def reset_kv_cache(self):
        """Clear the KV cache of every attention layer and rewind to position 0."""
        for blk in self.trf_blocks:
            blk.att.reset_cache()
        # Kept for backward compatibility; nothing in this file reads it
        self.ptr_current_pos = 0


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a shortcut connection.

    Args:
        cfg: Configuration dict with the keys shown in ``QWEN_CONFIG_06_B``.
            ``"window_size"`` is optional; when missing or None the attention
            layer falls back to ``"context_length"``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            head_dim=cfg["head_dim"],
            num_kv_groups=cfg["n_kv_groups"],
            qk_norm=cfg["qk_norm"],
            max_seq_len=cfg["context_length"],
            # Forward the configured KV-cache window (previously ignored)
            window_size=cfg.get("window_size"),
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)

    def forward(self, x, cos, sin, use_cache=False):
        """Apply the block to ``x`` of shape (batch_size, num_tokens, emb_dim)."""
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, cos, sin, use_cache)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x


class FeedForward(nn.Module):
    """Gated feed-forward layer: ``fc3(silu(fc1(x)) * fc2(x))``."""

    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)


class GroupedQueryAttention(nn.Module):
    """Grouped-query attention with RoPE and an optional pre-allocated KV cache.

    Args:
        d_in: Input (and output) feature dimension.
        num_heads: Number of query heads.
        num_kv_groups: Number of key/value heads; must divide ``num_heads``.
        head_dim: Per-head dimension; defaults to ``d_in // num_heads``.
        qk_norm: If True, apply RMSNorm to queries and keys per head.
        dtype: dtype of the projection weights.
        max_seq_len: Number of positions the KV cache is allocated for.
        window_size: With the cache enabled, each query attends to at most this
            many most recent positions (itself included). Defaults to
            ``max_seq_len``.
    """

    def __init__(
        self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False,
        dtype=None, max_seq_len=None, window_size=None
    ):
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
            head_dim = d_in // num_heads

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)

        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=1e-6)
            self.k_norm = RMSNorm(head_dim, eps=1e-6)
        else:
            self.q_norm = self.k_norm = None

        # For optional KV cache
        self.max_seq_len = max_seq_len
        self.window_size = window_size or self.max_seq_len
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        self.cache_initialized = False
        self.ptr = 0  # Next free position in the cache

    def forward(self, x, cos, sin, use_cache=False):
        """Compute causal self-attention over ``x``.

        Args:
            x: Input of shape (batch_size, num_tokens, d_in).
            cos, sin: RoPE tables of shape (num_positions, head_dim).
            use_cache: If True, the new keys/values are appended to the cache at
                ``self.ptr`` and the queries attend over the cached context
                (restricted to ``window_size`` positions per query). If False,
                the input is treated as a full sequence starting at position 0
                and the cache is left untouched.

        Returns:
            Tensor of shape (batch_size, num_tokens, d_in).

        Raises:
            RuntimeError: If the requested positions exceed the RoPE tables or
                the KV cache capacity.
        """
        b, num_tokens, _ = x.shape

        # Apply projections
        queries = self.W_query(x)      # (b, num_tokens, num_heads * head_dim)
        keys_new = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)
        values_new = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # Reshape to (b, heads, num_tokens, head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # Optional normalization
        if self.q_norm is not None:
            queries = self.q_norm(queries)
        if self.k_norm is not None:
            keys_new = self.k_norm(keys_new)

        # Absolute positions of the new tokens. Only a cached pass continues
        # from the cache pointer; an uncached pass always starts at 0 so that a
        # leftover pointer cannot shift it off the RoPE table.
        pos_start = self.ptr if use_cache else 0
        pos_end = pos_start + num_tokens
        if pos_end > cos.shape[0]:
            # A short slice would otherwise fail with an opaque shape error or,
            # with exactly one row left, broadcast silently.
            raise RuntimeError(
                f"Positions {pos_start}:{pos_end} exceed the RoPE table of {cos.shape[0]} positions"
            )
        cos_slice = cos[pos_start:pos_end]
        sin_slice = sin[pos_start:pos_end]

        # Apply RoPE
        keys_new = apply_rope(keys_new, cos_slice, sin_slice)
        queries = apply_rope(queries, cos_slice, sin_slice)

        # Expand K and V to match number of heads
        keys_new = keys_new.repeat_interleave(self.group_size, dim=1)
        values_new = values_new.repeat_interleave(self.group_size, dim=1)

        if use_cache:
            # Allocate on first use, and again when an empty (reset) cache was
            # built for a different batch size.
            stale = self.cache_initialized and self.ptr == 0 and self.cache_k.shape[0] != b
            if not self.cache_initialized or stale:
                self.cache_k = torch.zeros(
                    b, self.num_heads, self.max_seq_len, self.head_dim,
                    device=x.device, dtype=keys_new.dtype
                )
                self.cache_v = torch.zeros(
                    b, self.num_heads, self.max_seq_len, self.head_dim,
                    device=x.device, dtype=values_new.dtype
                )
                self.ptr = 0
                self.cache_initialized = True

            q_start = self.ptr
            end = q_start + num_tokens
            if end > self.max_seq_len:
                raise RuntimeError(
                    f"KV cache overflow: position {end} exceeds max_seq_len={self.max_seq_len}"
                )

            # In-place update
            self.cache_k[:, :, q_start:end].copy_(keys_new)
            self.cache_v[:, :, q_start:end].copy_(values_new)

            # Oldest key the first query of this chunk may still see. For a
            # single-token step this equals max(0, end - window_size).
            window = self.window_size
            k_start = max(0, q_start - window + 1)
            keys = self.cache_k[:, :, k_start:end]
            values = self.cache_v[:, :, k_start:end]
            self.ptr = end
        else:
            keys, values = keys_new, values_new
            q_start, k_start, window = 0, 0, None

        # Attention
        attn_scores = queries @ keys.transpose(2, 3)

        # Mask by absolute position so the mask stays aligned when the keys
        # include cached context (T_k > T_q). A single query sees every key in
        # its slice, so no mask is needed for it.
        T_q = queries.shape[-2]
        T_k = keys.shape[-2]
        if T_q > 1:
            q_pos = torch.arange(q_start, q_start + T_q, device=x.device).unsqueeze(1)
            k_pos = torch.arange(k_start, k_start + T_k, device=x.device).unsqueeze(0)
            mask = k_pos > q_pos  # Future positions
            if window is not None:
                # Positions that fell out of this query's window
                mask = mask | (k_pos <= q_pos - window)
            attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context)

    def reset_cache(self):
        """Zero the cached keys/values (if allocated) and rewind the pointer to 0."""
        if self.cache_k is not None:
            self.cache_k.zero_()
            self.cache_v.zero_()
        self.ptr = 0


def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    """Precompute the RoPE cosine and sine tables.

    Args:
        head_dim: Per-head dimension; must be even.
        theta_base: Base of the inverse-frequency schedule.
        context_length: Number of positions to precompute.
        dtype: dtype used for the index tensors.

    Returns:
        Tuple ``(cos, sin)``, each of shape (context_length, head_dim).
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    """Rotate ``x`` with the given RoPE tables.

    Args:
        x: Tensor of shape (batch_size, num_heads, seq_len, head_dim).
        cos, sin: Tables whose first ``seq_len`` rows correspond to the
            positions of ``x``; shape (>= seq_len, head_dim).

    Returns:
        Tensor with the shape and dtype of ``x``.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2:]   # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Args:
        emb_dim: Size of the normalized (last) dimension.
        eps: Added to the mean square before the reciprocal square root.
        bias: If True, add a learnable shift after scaling.
        qwen3_compatible: If True, compute in float32 and cast the result back
            to the input dtype.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        input_dtype = x.dtype

        if self.qwen3_compatible:
            x = x.to(torch.float32)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            norm_x = norm_x + self.shift

        return norm_x.to(input_dtype)