Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-06-26 23:50:03 +00:00)
Improve KV cache code for torch.compile (#705)
* Improve KV cache code for torch.compile
* cleanup
* cleanup
This commit is contained in:
parent 6522be94be
commit 81eda38d3b
@@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            torch.triu(torch.ones(context_length, context_length), diagonal=1),
             persistent=False
         )
 
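Side note (not part of the diff): `persistent=False` keeps the mask registered on the module, so it still moves with `.to(device)`, but it is excluded from the `state_dict`, so checkpoints do not grow with the configured context length. A minimal sketch of that behavior:

```python
import torch
import torch.nn as nn


class MaskOwner(nn.Module):
    def __init__(self, context_length=4):
        super().__init__()
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", mask, persistent=False)  # excluded from state_dict


m = MaskOwner()
print(list(m.state_dict().keys()))  # [] -- the non-persistent buffer is not saved
print(m.mask.shape)                 # torch.Size([4, 4]) -- but it is still available
```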
@@ -236,14 +236,14 @@ token_ids = generate_text_simple(
 )
 ```
 
-Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage dominates here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements).
+Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage results in even lower memory usage here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements; and KV-cache memory may increase prohibitively for longer context lengths).
 
 | Model       | Mode              | Hardware        | Tokens/sec | GPU Memory (VRAM) |
-|-------------|-------------------|-----------------|------------|-------------------|
+| ----------- | ----------------- | --------------- | ---------- | ----------------- |
 | Llama3Model | Regular           | Mac Mini M4 CPU | 1          | -                 |
 | Llama3Model | Regular compiled  | Mac Mini M4 CPU | -          | -                 |
-| Llama3Model | KV cache          | Mac Mini M4 CPU | 62         | -                 |
-| Llama3Model | KV cache compiled | Mac Mini M4 CPU | -          | -                 |
+| Llama3Model | KV cache          | Mac Mini M4 CPU | 68         | -                 |
+| Llama3Model | KV cache compiled | Mac Mini M4 CPU | 86         | -                 |
 |             |                   |                 |            |                   |
 | Llama3Model | Regular           | Mac Mini M4 GPU | 15         | -                 |
 | Llama3Model | Regular compiled  | Mac Mini M4 GPU | -          | -                 |
@@ -252,7 +252,7 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
 |             |                   |                 |            |                   |
 | Llama3Model | Regular           | Nvidia A100 GPU | 42         | 2.91 GB           |
 | Llama3Model | Regular compiled  | Nvidia A100 GPU | 170        | 3.12 GB           |
-| Llama3Model | KV cache          | Nvidia A100 GPU | 60         | 18.87 GB          |
-| Llama3Model | KV cache compiled | Nvidia A100 GPU | 59         | 19.12 GB          |
+| Llama3Model | KV cache          | Nvidia A100 GPU | 58         | 2.87 GB           |
+| Llama3Model | KV cache compiled | Nvidia A100 GPU | 161        | 3.61 GB           |
 
 Note that all settings above have been tested to produce the same text outputs.
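As a rough, back-of-the-envelope check on the A100 numbers above (an illustrative sketch, not part of the repository): the previous code pre-allocated per-layer key/value buffers for the full 131k-token context and for all 32 heads, whereas the new code only stores the keys/values actually produced so far, and only for the KV groups. Assuming bfloat16 storage and 8 KV heads (the KV-group count is not shown in this hunk):

```python
# Rough KV-cache size estimate for the Llama 3.2 1B settings quoted in this diff.
emb_dim, n_heads, n_layers = 2048, 32, 16
head_dim = emb_dim // n_heads        # 64
context_length = 131_072
n_kv_heads = 8                       # assumption; not shown in this hunk
bytes_per_elem = 2                   # bfloat16
batch_size = 1


def kv_bytes(num_heads_cached, seq_len):
    # K and V each have shape (batch, num_heads_cached, seq_len, head_dim) per layer
    return 2 * batch_size * n_layers * num_heads_cached * seq_len * head_dim * bytes_per_elem


# Old code: buffers pre-allocated at full context length, after expanding to all heads
print(f"{kv_bytes(n_heads, context_length) / 1e9:.1f} GB")  # ~17.2 GB

# New code: cache grows with the ~150 generated tokens and stores only the KV groups
print(f"{kv_bytes(n_kv_heads, 150) / 1e6:.1f} MB")          # ~4.9 MB
```

This lines up with the drop from roughly 19 GB to roughly 3 GB of peak VRAM in the table; the remainder is dominated by the model weights and activations.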
@@ -209,23 +209,23 @@ token_ids = generate_text_simple(
 )
 ```
 
-Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage dominates here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements).
+Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage results in even lower memory usage here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements; and KV-cache memory may increase prohibitively for longer context lengths).
 
 | Model      | Mode              | Hardware        | Tokens/sec | GPU Memory (VRAM) |
 | ---------- | ----------------- | --------------- | ---------- | ----------------- |
 | Qwen3Model | Regular           | Mac Mini M4 CPU | 1          | -                 |
 | Qwen3Model | Regular compiled  | Mac Mini M4 CPU | 1          | -                 |
 | Qwen3Model | KV cache          | Mac Mini M4 CPU | 80         | -                 |
-| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 82         | -                 |
+| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 137        | -                 |
 |            |                   |                 |            |                   |
 | Qwen3Model | Regular           | Mac Mini M4 GPU | 21         | -                 |
 | Qwen3Model | Regular compiled  | Mac Mini M4 GPU | Error      | -                 |
-| Qwen3Model | KV cache          | Mac Mini M4 GPU | 32         | -                 |
+| Qwen3Model | KV cache          | Mac Mini M4 GPU | 28         | -                 |
 | Qwen3Model | KV cache compiled | Mac Mini M4 GPU | Error      | -                 |
 |            |                   |                 |            |                   |
-| Qwen3Model | Regular           | Nvidia A100 GPU | 25         | 1.49 GB           |
+| Qwen3Model | Regular           | Nvidia A100 GPU | 26         | 1.49 GB           |
 | Qwen3Model | Regular compiled  | Nvidia A100 GPU | 107        | 1.99 GB           |
-| Qwen3Model | KV cache          | Nvidia A100 GPU | 25         | 10.20 GB          |
-| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 24         | 10.61 GB          |
+| Qwen3Model | KV cache          | Nvidia A100 GPU | 25         | 1.47 GB           |
+| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 90         | 1.48 GB           |
 
 Note that all settings above have been tested to produce the same text outputs.
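For reference, the "compiled" rows above are obtained by wrapping the model in `torch.compile` before generation. A minimal sketch of how that might look (the module paths below are assumptions based on the `pkg/llms_from_scratch/kv_cache/` layout in this PR, and the dummy prompt means the untrained output is only useful for timing):

```python
import torch

# Assumed import paths, following the pkg/llms_from_scratch/kv_cache/ layout in this PR
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model, QWEN_CONFIG_06_B
from llms_from_scratch.kv_cache.generate import generate_text_simple

torch.manual_seed(123)
model = Qwen3Model(QWEN_CONFIG_06_B).eval()
model = torch.compile(model)  # the "compiled" rows; the first call pays the compile cost

idx = torch.randint(0, QWEN_CONFIG_06_B["vocab_size"], (1, 8))  # dummy prompt tokens
token_ids = generate_text_simple(model=model, idx=idx, max_new_tokens=150)
print(token_ids.shape)  # (1, 8 + 150)
```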
@@ -3,23 +3,24 @@
 # - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
+from .utils import KVCache
 import torch
 
 
 def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
     ctx_len = context_size or model.cfg["context_length"]
+    cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None
 
     with torch.no_grad():
         if use_cache:
-            model.reset_kv_cache()
-            logits = model(idx[:, -ctx_len:], use_cache=True)
+            logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache)
 
             for _ in range(max_new_tokens):
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
-                logits = model(next_idx, use_cache=True)
+                logits = model(next_idx, use_cache=True, cache=cache)
         else:
             for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
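A small helper along these lines can be used to reproduce the tokens/sec comparison between the two branches of the function above (a sketch only; it assumes `generate_text_simple` and a cache-aware `model` as defined in this PR are in scope):

```python
import time

import torch


def time_generation(model, idx, max_new_tokens, use_cache=True):
    # Times the greedy decoding loop above and reports tokens per second.
    start = time.perf_counter()
    out = generate_text_simple(model, idx, max_new_tokens, use_cache=use_cache)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure pending GPU work is included in the timing
    elapsed = time.perf_counter() - start
    generated = out.shape[1] - idx.shape[1]
    print(f"use_cache={use_cache}: {generated / elapsed:.1f} tokens/sec")
    return out
```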
@ -3,6 +3,8 @@
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from .utils import KVCache # noqa: F401
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -11,7 +13,7 @@ import torch.nn as nn
|
||||
# Chapter 3
|
||||
#####################################
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None):
|
||||
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
|
||||
@ -25,80 +27,41 @@ class MultiHeadAttention(nn.Module):
|
||||
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
self.max_seq_len = max_seq_len or context_length
|
||||
self.window_size = window_size or self.max_seq_len
|
||||
self.register_buffer("cache_k", None, persistent=False)
|
||||
self.register_buffer("cache_v", None, persistent=False)
|
||||
####################################################
|
||||
|
||||
def forward(self, x, use_cache=False):
|
||||
def forward(self, x, use_cache=False, start_pos=0, cache=None):
|
||||
b, num_tokens, d_in = x.shape
|
||||
|
||||
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
|
||||
values_new = self.W_value(x)
|
||||
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
|
||||
values = self.W_value(x)
|
||||
queries = self.W_query(x)
|
||||
|
||||
# We implicitly split the matrix by adding a `num_heads` dimension
|
||||
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
|
||||
keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
|
||||
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
|
||||
keys_new = keys_new.transpose(1, 2)
|
||||
values_new = values_new.transpose(1, 2)
|
||||
keys = keys.transpose(1, 2)
|
||||
queries = queries.transpose(1, 2)
|
||||
values = values.transpose(1, 2)
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
if use_cache:
|
||||
if self.cache_k is None or self.cache_k.size(0) != b:
|
||||
self.cache_k = torch.zeros(b, self.num_heads,
|
||||
self.window_size, self.head_dim,
|
||||
device=x.device)
|
||||
self.cache_v = torch.zeros_like(self.cache_k)
|
||||
self.ptr_cur = 0 # pointer to next free slot
|
||||
|
||||
# if incoming chunk would overflow, discard oldest tokens
|
||||
if self.ptr_cur + num_tokens > self.window_size:
|
||||
overflow = self.ptr_cur + num_tokens - self.window_size
|
||||
# shift everything left by `overflow` (cheap view-copy)
|
||||
self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
|
||||
self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
|
||||
self.ptr_cur -= overflow # pointer after shift
|
||||
|
||||
self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
|
||||
self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
|
||||
self.ptr_cur += num_tokens
|
||||
|
||||
keys = self.cache_k[:, :, :self.ptr_cur, :]
|
||||
values = self.cache_v[:, :, :self.ptr_cur, :]
|
||||
if cache is not None:
|
||||
keys = torch.cat([cache[0], keys], dim=2)
|
||||
values = torch.cat([cache[1], values], dim=2)
|
||||
next_cache = (keys, values)
|
||||
else:
|
||||
keys, values = keys_new, values_new
|
||||
self.ptr_cur = 0 # keep pointer sane if you interleave modes
|
||||
####################################################
|
||||
next_cache = None
|
||||
|
||||
seq_len = keys.size(2)
|
||||
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
|
||||
causal_mask = causal_mask[:, -num_tokens:][None, None, :, :]
|
||||
|
||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
K = attn_scores.size(-1)
|
||||
|
||||
if num_tokens == K:
|
||||
# No cache → use the pre‑baked triangular mask slice
|
||||
causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
|
||||
else:
|
||||
# Cached: need to offset the diagonal by (K − num_tokens)
|
||||
offset = K - num_tokens # number of tokens already in cache before this chunk
|
||||
row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1) # (num_tokens, 1)
|
||||
col_idx = torch.arange(K, device=x.device).unsqueeze(0) # (1, K)
|
||||
causal_mask = row_idx + offset < col_idx # True where j > i+offset
|
||||
####################################################
|
||||
|
||||
# Use the mask to fill attention scores
|
||||
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
|
||||
attn_scores.masked_fill_(causal_mask, -torch.inf)
|
||||
|
||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||
attn_weights = self.dropout(attn_weights)
|
||||
@ -110,13 +73,7 @@ class MultiHeadAttention(nn.Module):
|
||||
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
|
||||
context_vec = self.out_proj(context_vec) # optional projection
|
||||
|
||||
return context_vec
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
def reset_cache(self):
|
||||
self.cache_k, self.cache_v = None, None
|
||||
####################################################
|
||||
return context_vec, next_cache
|
||||
|
||||
|
||||
#####################################
|
||||
@ -169,25 +126,17 @@ class TransformerBlock(nn.Module):
|
||||
context_length=cfg["context_length"],
|
||||
num_heads=cfg["n_heads"],
|
||||
dropout=cfg["drop_rate"],
|
||||
qkv_bias=cfg["qkv_bias"],
|
||||
window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"] # NEW
|
||||
)
|
||||
qkv_bias=cfg["qkv_bias"])
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = LayerNorm(cfg["emb_dim"])
|
||||
self.norm2 = LayerNorm(cfg["emb_dim"])
|
||||
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
|
||||
|
||||
def forward(self, x, use_cache=False):
|
||||
def forward(self, x, use_cache=False, start_pos=0, cache=None):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
|
||||
# x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
|
||||
####################################################
|
||||
# NEW
|
||||
x = self.att(x, use_cache=use_cache)
|
||||
####################################################
|
||||
|
||||
x, next_cache = self.att(x, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = self.drop_shortcut(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
@ -198,7 +147,7 @@ class TransformerBlock(nn.Module):
|
||||
x = self.drop_shortcut(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x
|
||||
return x, next_cache
|
||||
|
||||
|
||||
class GPTModel(nn.Module):
|
||||
@ -208,80 +157,34 @@ class GPTModel(nn.Module):
|
||||
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
|
||||
self.drop_emb = nn.Dropout(cfg["drop_rate"])
|
||||
|
||||
# self.trf_blocks = nn.Sequential(
|
||||
# *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
####################################################
|
||||
# NEW
|
||||
self.trf_blocks = nn.ModuleList(
|
||||
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
|
||||
self.ptr_current_pos = 0
|
||||
####################################################
|
||||
self.trf_blocks = nn.Sequential(
|
||||
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
|
||||
self.final_norm = LayerNorm(cfg["emb_dim"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
|
||||
self.current_pos = 0
|
||||
|
||||
def forward(self, in_idx, use_cache=False):
|
||||
def forward(self, in_idx, use_cache=False, cache=None):
|
||||
batch_size, seq_len = in_idx.shape
|
||||
pos = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device)
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
|
||||
# pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
pos_embeds = self.pos_emb(pos)
|
||||
x = self.drop_emb(tok_embeds + pos_embeds)
|
||||
|
||||
if use_cache:
|
||||
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
|
||||
self.ptr_current_pos += seq_len
|
||||
start_pos = self.current_pos
|
||||
self.current_pos += seq_len
|
||||
else:
|
||||
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
|
||||
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
|
||||
####################################################
|
||||
start_pos = 0
|
||||
|
||||
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
|
||||
x = self.drop_emb(x)
|
||||
|
||||
# x = self.trf_blocks(x)
|
||||
####################################################
|
||||
# NEW
|
||||
for blk in self.trf_blocks:
|
||||
x = blk(x, use_cache=use_cache)
|
||||
####################################################
|
||||
next_cache = []
|
||||
for i, block in enumerate(self.trf_blocks):
|
||||
blk_cache = cache.get(i) if cache else None
|
||||
x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
|
||||
if cache:
|
||||
cache.update(i, new_cache)
|
||||
next_cache.append(new_cache)
|
||||
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x)
|
||||
return logits
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
def reset_kv_cache(self):
|
||||
for blk in self.trf_blocks:
|
||||
blk.att.reset_cache()
|
||||
self.ptr_current_pos = 0
|
||||
####################################################
|
||||
|
||||
|
||||
def generate_text_simple(model, idx, max_new_tokens, context_size):
|
||||
# idx is (B, T) array of indices in the current context
|
||||
for _ in range(max_new_tokens):
|
||||
|
||||
# Crop current context if it exceeds the supported context size
|
||||
# E.g., if LLM supports only 5 tokens, and the context size is 10
|
||||
# then only the last 5 tokens are used as context
|
||||
idx_cond = idx[:, -context_size:]
|
||||
|
||||
# Get the predictions
|
||||
with torch.no_grad():
|
||||
logits = model(idx_cond)
|
||||
|
||||
# Focus only on the last time step
|
||||
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
# Get the idx of the vocab entry with the highest logits value
|
||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
|
||||
|
||||
# Append sampled index to the running sequence
|
||||
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
|
||||
|
||||
return idx
|
||||
|
@ -3,15 +3,20 @@
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from ..llama3 import Llama3Tokenizer, ChatFormat, clean_text # noqa: F401
|
||||
from .utils import KVCache # noqa: F401
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
|
||||
|
||||
LLAMA32_CONFIG_1B = {
|
||||
"vocab_size": 128_256, # Vocabulary size
|
||||
"context_length": 131_072, # Context length that was used to train the model
|
||||
"window_size": None, # Window size for the KV cache; context_length if None
|
||||
"emb_dim": 2048, # Embedding dimension
|
||||
"n_heads": 32, # Number of attention heads
|
||||
"n_layers": 16, # Number of layers
|
||||
@ -30,7 +35,6 @@ LLAMA32_CONFIG_1B = {
|
||||
LLAMA32_CONFIG_3B = {
|
||||
"vocab_size": 128_256, # Vocabulary size
|
||||
"context_length": 131_072, # Context length that was used to train the model
|
||||
"window_size": None, # Window size for the KV cache; context_length if None
|
||||
"emb_dim": 3072, # Embedding dimension
|
||||
"n_heads": 24, # Number of attention heads
|
||||
"n_layers": 28, # Number of layers
|
||||
@ -71,21 +75,45 @@ class Llama3Model(nn.Module):
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
self.cfg = cfg
|
||||
self.current_pos = 0 # Track current position in KV cache
|
||||
|
||||
def forward(self, in_idx, use_cache=False):
|
||||
def forward(self, in_idx, use_cache=False, cache=None):
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
x = tok_embeds
|
||||
|
||||
for block in self.trf_blocks:
|
||||
x = block(x, self.cos, self.sin, use_cache)
|
||||
num_tokens = x.shape[1]
|
||||
if use_cache:
|
||||
pos_start = self.current_pos
|
||||
pos_end = pos_start + num_tokens
|
||||
self.current_pos = pos_end
|
||||
mask = torch.triu(
|
||||
torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool), diagonal=1
|
||||
)[pos_start:pos_end, :pos_end]
|
||||
else:
|
||||
pos_start = 0 # Not strictly necessary but helps torch.compile
|
||||
mask = torch.triu(
|
||||
torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1
|
||||
)
|
||||
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
|
||||
mask = mask[None, None, :, :]
|
||||
|
||||
next_cache = []
|
||||
for i, block in enumerate(self.trf_blocks):
|
||||
blk_cache = cache.get(i) if cache else None
|
||||
x, new_blk_cache = block(x, mask, self.cos, self.sin,
|
||||
use_cache=use_cache,
|
||||
start_pos=pos_start,
|
||||
cache=blk_cache)
|
||||
if cache:
|
||||
cache.update(i, new_blk_cache)
|
||||
next_cache.append(new_blk_cache)
|
||||
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
return logits
|
||||
|
||||
def reset_kv_cache(self):
|
||||
for blk in self.trf_blocks:
|
||||
blk.att.reset_cache()
|
||||
self.ptr_current_pos = 0
|
||||
self.current_pos = 0
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
@ -96,18 +124,17 @@ class TransformerBlock(nn.Module):
|
||||
d_out=cfg["emb_dim"],
|
||||
num_heads=cfg["n_heads"],
|
||||
num_kv_groups=cfg["n_kv_groups"],
|
||||
max_seq_len=cfg["context_length"],
|
||||
dtype=cfg["dtype"]
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||
self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||
|
||||
def forward(self, x, cos, sin, use_cache=False):
|
||||
def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size]
|
||||
x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
@ -116,7 +143,7 @@ class TransformerBlock(nn.Module):
|
||||
x = self.ff(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x
|
||||
return x, next_cache
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
@ -135,7 +162,7 @@ class FeedForward(nn.Module):
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(
|
||||
self, d_in, d_out, num_heads, num_kv_groups, dtype=None, max_seq_len=None, window_size=None
|
||||
self, d_in, d_out, num_heads, num_kv_groups, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
@ -153,45 +180,40 @@ class GroupedQueryAttention(nn.Module):
|
||||
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
|
||||
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
|
||||
|
||||
# For optional KV cache
|
||||
self.max_seq_len = max_seq_len
|
||||
self.window_size = window_size or self.max_seq_len
|
||||
self.register_buffer("cache_k", None, persistent=False)
|
||||
self.register_buffer("cache_v", None, persistent=False)
|
||||
self.cache_initialized = False
|
||||
self.ptr = 0
|
||||
def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
|
||||
b, num_tokens, _ = x.shape
|
||||
|
||||
def forward(self, x, cos, sin, use_cache=False):
|
||||
b, num_tokens, d_in = x.shape
|
||||
# Apply projections
|
||||
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
|
||||
keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
|
||||
queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
|
||||
keys_new = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
|
||||
values_new = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
|
||||
|
||||
# Reshape queries, keys, and values
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
|
||||
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
|
||||
|
||||
# Transpose keys, values, and queries
|
||||
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
keys_new = keys_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
|
||||
values_new = values_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
|
||||
|
||||
# For KV cache
|
||||
pos_start = self.ptr
|
||||
pos_end = pos_start + num_tokens
|
||||
cos_slice = cos[pos_start:pos_end]
|
||||
sin_slice = sin[pos_start:pos_end]
|
||||
# Reshape
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Apply RoPE
|
||||
keys_new = apply_rope(keys_new, cos_slice, sin_slice)
|
||||
queries = apply_rope(queries, cos_slice, sin_slice)
|
||||
queries = apply_rope(queries, cos, sin, offset=start_pos)
|
||||
keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)
|
||||
|
||||
if use_cache:
|
||||
if cache is None:
|
||||
keys = keys_new
|
||||
values = values_new
|
||||
else:
|
||||
prev_k, prev_v = cache
|
||||
keys = torch.cat([prev_k, keys_new], dim=2)
|
||||
values = torch.cat([prev_v, values_new], dim=2)
|
||||
next_cache = (keys, values)
|
||||
else:
|
||||
keys, values = keys_new, values_new
|
||||
next_cache = None
|
||||
|
||||
# Expand keys and values to match the number of heads
|
||||
# Shape: (b, num_heads, num_tokens, head_dim)
|
||||
keys_new = keys_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
values_new = values_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
|
||||
# For example, before repeat_interleave along dim=1 (query groups):
|
||||
# [K1, K2]
|
||||
# After repeat_interleave (each query group is repeated group_size times):
|
||||
@ -199,38 +221,12 @@ class GroupedQueryAttention(nn.Module):
|
||||
# If we used regular repeat instead of repeat_interleave, we'd get:
|
||||
# [K1, K2, K1, K2]
|
||||
|
||||
if use_cache:
|
||||
if not self.cache_initialized:
|
||||
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
|
||||
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
|
||||
self.ptr = 0
|
||||
self.cache_initialized = True
|
||||
|
||||
# In-place update
|
||||
end = self.ptr + num_tokens
|
||||
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
|
||||
self.cache_v[:, :, self.ptr:end].copy_(values_new)
|
||||
|
||||
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
|
||||
values = self.cache_v[:, :, max(0, end - self.window_size):end]
|
||||
self.ptr = end
|
||||
else:
|
||||
keys, values = keys_new, values_new
|
||||
|
||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||
# Shape: (b, num_heads, num_tokens, num_tokens)
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
|
||||
# Create causal mask to fill attention scores
|
||||
T_q = queries.shape[-2]
|
||||
T_k = keys.shape[-2]
|
||||
|
||||
if not use_cache or T_q > 1:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
|
||||
diagonal=1
|
||||
)
|
||||
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
|
||||
# Use the mask to fill attention scores
|
||||
attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
|
||||
|
||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||
assert keys.shape[-1] == self.head_dim
|
||||
@ -242,13 +238,7 @@ class GroupedQueryAttention(nn.Module):
|
||||
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
|
||||
context_vec = self.out_proj(context_vec) # optional projection
|
||||
|
||||
return context_vec
|
||||
|
||||
def reset_cache(self):
|
||||
if self.cache_k is not None:
|
||||
self.cache_k.zero_()
|
||||
self.cache_v.zero_()
|
||||
self.ptr = 0
|
||||
return context_vec, next_cache
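The `repeat_interleave` vs. `repeat` comment in the hunk above can be checked with a tiny standalone example (illustrative sketch, two KV groups and `group_size=2`):

```python
import torch

# Shape (b=1, num_kv_groups=2, seq_len=1, head_dim=1): one "K1" group and one "K2" group
k = torch.tensor([[[[1.0]], [[2.0]]]])

print(k.repeat_interleave(2, dim=1).flatten().tolist())  # [1.0, 1.0, 2.0, 2.0] -> [K1, K1, K2, K2]
print(k.repeat(1, 2, 1, 1).flatten().tolist())           # [1.0, 2.0, 1.0, 2.0] -> [K1, K2, K1, K2]
```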
|
||||
|
||||
|
||||
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
|
||||
@ -296,7 +286,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_c
|
||||
return cos, sin
|
||||
|
||||
|
||||
def apply_rope(x, cos, sin):
|
||||
def apply_rope(x, cos, sin, offset=0):
|
||||
# x: (batch_size, num_heads, seq_len, head_dim)
|
||||
batch_size, num_heads, seq_len, head_dim = x.shape
|
||||
assert head_dim % 2 == 0, "Head dimension must be even"
|
||||
@ -306,8 +296,8 @@ def apply_rope(x, cos, sin):
|
||||
x2 = x[..., head_dim // 2:] # Second half
|
||||
|
||||
# Adjust sin and cos shapes
|
||||
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
|
||||
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
|
||||
cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
|
||||
sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Apply the rotary transformation
|
||||
rotated = torch.cat((-x2, x1), dim=-1)
|
||||
@ -315,3 +305,231 @@ def apply_rope(x, cos, sin):
|
||||
|
||||
# It's ok to use lower-precision after applying cos and sin rotation
|
||||
return x_rotated.to(dtype=x.dtype)
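The new `offset` argument selects which rows of the precomputed `cos`/`sin` tables are applied, so a token decoded on its own at global position `start_pos` receives the same rotation it would get in a full, uncached forward pass. A quick consistency check (illustrative; it uses the `compute_rope_params` and `apply_rope` defined in this file):

```python
import torch

cos, sin = compute_rope_params(head_dim=64, context_length=32)
x = torch.randn(1, 1, 8, 64)  # (batch, num_heads, seq_len, head_dim)

full = apply_rope(x, cos, sin)                      # all 8 positions in one pass
last = apply_rope(x[:, :, 7:], cos, sin, offset=7)  # only the last token, cached-style

print(torch.allclose(full[:, :, 7:], last))  # True
```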
|
||||
|
||||
|
||||
##########################################
|
||||
# Tokenizer
|
||||
##########################################
|
||||
|
||||
|
||||
class Llama3Tokenizer:
|
||||
"""Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
|
||||
def __init__(self, model_path):
|
||||
if not os.path.isfile(model_path):
|
||||
raise FileNotFoundError(model_path)
|
||||
|
||||
mergeable = load_tiktoken_bpe(model_path)
|
||||
|
||||
# hard-coded from Meta's tokenizer.json
|
||||
self.special = {
|
||||
"<|begin_of_text|>": 128000,
|
||||
"<|end_of_text|>": 128001,
|
||||
"<|start_header_id|>": 128006,
|
||||
"<|end_header_id|>": 128007,
|
||||
"<|eot_id|>": 128009,
|
||||
}
|
||||
self.special.update({f"<|reserved_{i}|>": 128002 + i
|
||||
for i in range(256)
|
||||
if 128002 + i not in self.special.values()})
|
||||
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(model_path).name,
|
||||
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
|
||||
r"|[^\r\n\p{L}\p{N}]?\p{L}+"
|
||||
r"|\p{N}{1,3}"
|
||||
r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
|
||||
r"|\s*[\r\n]+"
|
||||
r"|\s+(?!\S)"
|
||||
r"|\s+",
|
||||
mergeable_ranks=mergeable,
|
||||
special_tokens=self.special,
|
||||
)
|
||||
|
||||
def encode(self, text, bos=False, eos=False, **kwargs):
|
||||
ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
|
||||
+ self.model.encode(text)
|
||||
if eos:
|
||||
ids.append(self.special["<|end_of_text|>"])
|
||||
return ids
|
||||
|
||||
def decode(self, ids):
|
||||
return self.model.decode(ids)
|
||||
|
||||
|
||||
class ChatFormat:
|
||||
|
||||
def __init__(self, tokenizer: Llama3Tokenizer, *,
|
||||
default_system="You are a helpful assistant."):
|
||||
self.tok = tokenizer
|
||||
self.default_system = default_system
|
||||
|
||||
def _header(self, role):
|
||||
"""Encode <|start_header_id|>role<|end_header_id|>\n\n"""
|
||||
return (
|
||||
[self.tok.special["<|start_header_id|>"]]
|
||||
+ self.tok.encode(role)
|
||||
+ [self.tok.special["<|end_header_id|>"]]
|
||||
+ self.tok.encode("\n\n")
|
||||
)
|
||||
|
||||
def encode(self, user_message, system_message=None, allowed_special=None):
|
||||
sys_msg = system_message if system_message is not None else self.default_system
|
||||
|
||||
ids = [self.tok.special["<|begin_of_text|>"]]
|
||||
|
||||
# system
|
||||
ids += self._header("system")
|
||||
ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
|
||||
ids += [self.tok.special["<|eot_id|>"]]
|
||||
|
||||
# user
|
||||
ids += self._header("user")
|
||||
ids += self.tok.encode(user_message)
|
||||
ids += [self.tok.special["<|eot_id|>"]]
|
||||
|
||||
# assistant header (no content yet)
|
||||
ids += self._header("assistant")
|
||||
|
||||
return ids
|
||||
|
||||
def decode(self, ids):
|
||||
return self.tok.decode(ids)
|
||||
|
||||
|
||||
def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
|
||||
# Find the index of the first occurrence of "<|end_header_id|>"
|
||||
index = text.find(header_end)
|
||||
|
||||
if index != -1:
|
||||
# Return the substring starting after "<|end_header_id|>"
|
||||
return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace
|
||||
else:
|
||||
# If the token is not found, return the original text
|
||||
return text
|
||||
|
||||
|
||||
######################################################################
|
||||
# Llama 3 fast (alternative code geared towards efficiency)
|
||||
######################################################################
|
||||
|
||||
class GroupedQueryAttentionFast(nn.Module):
|
||||
"""
|
||||
Drop-in replacement for GroupedQueryAttention but using PyTorch's
|
||||
scaled_dot_product_attention, which uses FlashAttention if run
|
||||
on an Ampere GPU (like A100) or newer and uses float16/bfloat16 or lower.
|
||||
"""
|
||||
def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
|
||||
|
||||
self.d_out = d_out
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_out // num_heads
|
||||
self.num_kv_groups = num_kv_groups
|
||||
self.group_size = num_heads // num_kv_groups
|
||||
|
||||
self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
|
||||
self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
|
||||
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
|
||||
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
|
||||
|
||||
def forward(self, x, cos, sin):
|
||||
b, num_tokens, _ = x.shape
|
||||
|
||||
# Project to queries, keys, values
|
||||
q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
v = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Apply Rotary Positional Embedding
|
||||
q = apply_rope(q, cos, sin)
|
||||
k = apply_rope(k, cos, sin)
|
||||
|
||||
# Expand key/value groups to full head count
|
||||
k = k.repeat_interleave(self.group_size, dim=1)
|
||||
v = v.repeat_interleave(self.group_size, dim=1)
|
||||
|
||||
# Efficient scaled dot-product attention
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
is_causal=True # Enables Flash/FlexAttention kernels
|
||||
)
|
||||
|
||||
# Combine heads and project
|
||||
attn_output = attn_output.transpose(1, 2).reshape(b, num_tokens, self.d_out)
|
||||
return self.out_proj(attn_output)
|
||||
|
||||
|
||||
class TransformerBlockFast(nn.Module):
|
||||
"""
|
||||
Same as original TransformerBlock but uses
|
||||
GroupedQueryAttentionFast instead of GroupedQueryAttention.
|
||||
"""
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.att = GroupedQueryAttentionFast(
|
||||
d_in=cfg["emb_dim"],
|
||||
d_out=cfg["emb_dim"],
|
||||
num_heads=cfg["n_heads"],
|
||||
num_kv_groups=cfg["n_kv_groups"],
|
||||
dtype=cfg["dtype"]
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||
self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||
|
||||
def forward(self, x, cos, sin):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = self.att(x, cos, sin) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
shortcut = x
|
||||
x = self.norm2(x)
|
||||
x = self.ff(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Llama3ModelFast(nn.Module):
|
||||
"""
|
||||
Same as original Llama3Model but uses TransformerBlockFast
|
||||
instead of TransformerBlock, which in turn uses
|
||||
GroupedQueryAttentionFast instead of GroupedQueryAttention.
|
||||
"""
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
# Main model parameters
|
||||
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
|
||||
|
||||
self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, cos, sin`
|
||||
[TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]
|
||||
)
|
||||
|
||||
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
||||
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=cfg["emb_dim"] // cfg["n_heads"],
|
||||
theta_base=cfg["rope_base"],
|
||||
context_length=cfg["context_length"],
|
||||
freq_config=cfg["rope_freq"]
|
||||
)
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
self.cfg = cfg
|
||||
|
||||
def forward(self, in_idx):
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
x = tok_embeds
|
||||
|
||||
for block in self.trf_blocks:
|
||||
x = block(x, self.cos, self.sin)
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
return logits
|
||||
|
@ -3,7 +3,11 @@
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from ..qwen3 import Qwen3Tokenizer, download_from_huggingface, load_weights_into_qwen # noqa: F401
|
||||
from .utils import KVCache # noqa: F401
|
||||
|
||||
import os
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -12,7 +16,6 @@ import torch.nn as nn
|
||||
QWEN_CONFIG_06_B = {
|
||||
"vocab_size": 151_936, # Vocabulary size
|
||||
"context_length": 40_960, # Context length that was used to train the model
|
||||
"window_size": None, # Window size for the KV cache; context_length if None
|
||||
"emb_dim": 1024, # Embedding dimension
|
||||
"n_heads": 16, # Number of attention heads
|
||||
"n_layers": 28, # Number of layers
|
||||
@ -51,22 +54,46 @@ class Qwen3Model(nn.Module):
|
||||
self.register_buffer("cos", cos, persistent=False)
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
self.cfg = cfg
|
||||
self.current_pos = 0 # Track current position in KV cache
|
||||
|
||||
def forward(self, in_idx, use_cache=False):
|
||||
def forward(self, in_idx, use_cache=False, cache=None):
|
||||
# Forward pass
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
x = tok_embeds
|
||||
|
||||
for block in self.trf_blocks:
|
||||
x = block(x, self.cos, self.sin, use_cache)
|
||||
num_tokens = x.shape[1]
|
||||
if use_cache:
|
||||
pos_start = self.current_pos
|
||||
pos_end = pos_start + num_tokens
|
||||
self.current_pos = pos_end
|
||||
mask = torch.triu(
|
||||
torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool), diagonal=1
|
||||
)[pos_start:pos_end, :pos_end]
|
||||
else:
|
||||
pos_start = 0 # Not strictly necessary but helps torch.compile
|
||||
mask = torch.triu(
|
||||
torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1
|
||||
)
|
||||
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
|
||||
mask = mask[None, None, :, :]
|
||||
|
||||
next_cache = []
|
||||
for i, block in enumerate(self.trf_blocks):
|
||||
blk_cache = cache.get(i) if cache else None
|
||||
x, new_blk_cache = block(x, mask, self.cos, self.sin,
|
||||
use_cache=use_cache,
|
||||
start_pos=pos_start,
|
||||
cache=blk_cache)
|
||||
if cache:
|
||||
cache.update(i, new_blk_cache)
|
||||
next_cache.append(new_blk_cache)
|
||||
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
return logits
|
||||
|
||||
def reset_kv_cache(self):
|
||||
for blk in self.trf_blocks:
|
||||
blk.att.reset_cache()
|
||||
self.ptr_current_pos = 0
|
||||
self.current_pos = 0
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
@ -78,18 +105,17 @@ class TransformerBlock(nn.Module):
|
||||
head_dim=cfg["head_dim"],
|
||||
num_kv_groups=cfg["n_kv_groups"],
|
||||
qk_norm=cfg["qk_norm"],
|
||||
max_seq_len=cfg["context_length"],
|
||||
dtype=cfg["dtype"]
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
||||
|
||||
def forward(self, x, cos, sin, use_cache=False):
|
||||
def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size]
|
||||
x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
@ -98,7 +124,7 @@ class TransformerBlock(nn.Module):
|
||||
x = self.ff(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x
|
||||
return x, next_cache
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
@ -117,8 +143,7 @@ class FeedForward(nn.Module):
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(
|
||||
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None,
|
||||
max_seq_len=None, window_size=None
|
||||
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
|
||||
@ -146,26 +171,18 @@ class GroupedQueryAttention(nn.Module):
|
||||
else:
|
||||
self.q_norm = self.k_norm = None
|
||||
|
||||
# For optional KV cache
|
||||
self.max_seq_len = max_seq_len
|
||||
self.window_size = window_size or self.max_seq_len
|
||||
self.register_buffer("cache_k", None, persistent=False)
|
||||
self.register_buffer("cache_v", None, persistent=False)
|
||||
self.cache_initialized = False
|
||||
self.ptr = 0
|
||||
|
||||
def forward(self, x, cos, sin, use_cache=False):
|
||||
def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
|
||||
b, num_tokens, _ = x.shape
|
||||
|
||||
# Apply projections
|
||||
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
|
||||
keys_new = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
values_new = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
|
||||
|
||||
# Reshape
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Optional normalization
|
||||
if self.q_norm:
|
||||
@ -173,62 +190,34 @@ class GroupedQueryAttention(nn.Module):
|
||||
if self.k_norm:
|
||||
keys_new = self.k_norm(keys_new)
|
||||
|
||||
# For KV cache
|
||||
pos_start = self.ptr
|
||||
pos_end = pos_start + num_tokens
|
||||
cos_slice = cos[pos_start:pos_end]
|
||||
sin_slice = sin[pos_start:pos_end]
|
||||
|
||||
# Apply RoPE
|
||||
keys_new = apply_rope(keys_new, cos_slice, sin_slice)
|
||||
queries = apply_rope(queries, cos_slice, sin_slice)
|
||||
|
||||
# Expand K and V to match number of heads
|
||||
keys_new = keys_new.repeat_interleave(self.group_size, dim=1)
|
||||
values_new = values_new.repeat_interleave(self.group_size, dim=1)
|
||||
queries = apply_rope(queries, cos, sin, offset=start_pos)
|
||||
keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)
|
||||
|
||||
if use_cache:
|
||||
if not self.cache_initialized:
|
||||
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
|
||||
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
|
||||
self.ptr = 0
|
||||
self.cache_initialized = True
|
||||
|
||||
# In-place update
|
||||
end = self.ptr + num_tokens
|
||||
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
|
||||
self.cache_v[:, :, self.ptr:end].copy_(values_new)
|
||||
|
||||
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
|
||||
values = self.cache_v[:, :, max(0, end - self.window_size):end]
|
||||
self.ptr = end
|
||||
if cache is None:
|
||||
keys = keys_new
|
||||
values = values_new
|
||||
else:
|
||||
prev_k, prev_v = cache
|
||||
keys = torch.cat([prev_k, keys_new], dim=2)
|
||||
values = torch.cat([prev_v, values_new], dim=2)
|
||||
next_cache = (keys, values)
|
||||
else:
|
||||
keys, values = keys_new, values_new
|
||||
next_cache = None
|
||||
|
||||
# Expand K and V to match number of heads
|
||||
keys = keys.repeat_interleave(self.group_size, dim=1)
|
||||
values = values.repeat_interleave(self.group_size, dim=1)
|
||||
|
||||
# Attention
|
||||
attn_scores = queries @ keys.transpose(2, 3)
|
||||
|
||||
# Create causal mask to fill attention scores
|
||||
T_q = queries.shape[-2]
|
||||
T_k = keys.shape[-2]
|
||||
|
||||
if not use_cache or T_q > 1:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
|
||||
diagonal=1
|
||||
)
|
||||
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
|
||||
|
||||
attn_scores = attn_scores.masked_fill(mask, -torch.inf)
|
||||
attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
|
||||
|
||||
context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
|
||||
return self.out_proj(context)
|
||||
|
||||
def reset_cache(self):
|
||||
if self.cache_k is not None:
|
||||
self.cache_k.zero_()
|
||||
self.cache_v.zero_()
|
||||
self.ptr = 0
|
||||
return self.out_proj(context), next_cache
|
||||
|
||||
|
||||
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
|
||||
@ -253,7 +242,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=
|
||||
return cos, sin
|
||||
|
||||
|
||||
def apply_rope(x, cos, sin):
|
||||
def apply_rope(x, cos, sin, offset=0):
|
||||
# x: (batch_size, num_heads, seq_len, head_dim)
|
||||
batch_size, num_heads, seq_len, head_dim = x.shape
|
||||
assert head_dim % 2 == 0, "Head dimension must be even"
|
||||
@ -263,8 +252,8 @@ def apply_rope(x, cos, sin):
|
||||
x2 = x[..., head_dim // 2:] # Second half
|
||||
|
||||
# Adjust sin and cos shapes
|
||||
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
|
||||
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
|
||||
cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
|
||||
sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Apply the rotary transformation
|
||||
rotated = torch.cat((-x2, x1), dim=-1)
|
||||
@ -297,3 +286,149 @@ class RMSNorm(nn.Module):
|
||||
|
||||
return norm_x.to(input_dtype)
|
||||
|
||||
|
||||
def load_weights_into_qwen(model, param_config, params):
|
||||
def assign(left, right, tensor_name="unknown"):
|
||||
if left.shape != right.shape:
|
||||
raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
|
||||
return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))
|
||||
|
||||
model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
for l in range(param_config["n_layers"]):
|
||||
block = model.trf_blocks[l]
|
||||
att = block.att
|
||||
|
||||
# Q, K, V projections
|
||||
att.W_query.weight = assign(
|
||||
att.W_query.weight,
|
||||
params[f"model.layers.{l}.self_attn.q_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.q_proj.weight"
|
||||
)
|
||||
att.W_key.weight = assign(
|
||||
att.W_key.weight,
|
||||
params[f"model.layers.{l}.self_attn.k_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.k_proj.weight"
|
||||
)
|
||||
att.W_value.weight = assign(
|
||||
att.W_value.weight,
|
||||
params[f"model.layers.{l}.self_attn.v_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.v_proj.weight"
|
||||
)
|
||||
|
||||
# Output projection
|
||||
att.out_proj.weight = assign(
|
||||
att.out_proj.weight,
|
||||
params[f"model.layers.{l}.self_attn.o_proj.weight"],
|
||||
f"model.layers.{l}.self_attn.o_proj.weight"
|
||||
)
|
||||
|
||||
# QK norms
|
||||
if hasattr(att, "q_norm") and att.q_norm is not None:
|
||||
att.q_norm.scale = assign(
|
||||
att.q_norm.scale,
|
||||
params[f"model.layers.{l}.self_attn.q_norm.weight"],
|
||||
f"model.layers.{l}.self_attn.q_norm.weight"
|
||||
)
|
||||
if hasattr(att, "k_norm") and att.k_norm is not None:
|
||||
att.k_norm.scale = assign(
|
||||
att.k_norm.scale,
|
||||
params[f"model.layers.{l}.self_attn.k_norm.weight"],
|
||||
f"model.layers.{l}.self_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
# Attention layernorm
|
||||
block.norm1.scale = assign(
|
||||
block.norm1.scale,
|
||||
params[f"model.layers.{l}.input_layernorm.weight"],
|
||||
f"model.layers.{l}.input_layernorm.weight"
|
||||
)
|
||||
|
||||
# Feedforward weights
|
||||
block.ff.fc1.weight = assign(
|
||||
block.ff.fc1.weight,
|
||||
params[f"model.layers.{l}.mlp.gate_proj.weight"],
|
||||
f"model.layers.{l}.mlp.gate_proj.weight"
|
||||
)
|
||||
block.ff.fc2.weight = assign(
|
||||
block.ff.fc2.weight,
|
||||
params[f"model.layers.{l}.mlp.up_proj.weight"],
|
||||
f"model.layers.{l}.mlp.up_proj.weight"
|
||||
)
|
||||
block.ff.fc3.weight = assign(
|
||||
block.ff.fc3.weight,
|
||||
params[f"model.layers.{l}.mlp.down_proj.weight"],
|
||||
f"model.layers.{l}.mlp.down_proj.weight"
|
||||
)
|
||||
block.norm2.scale = assign(
|
||||
block.norm2.scale,
|
||||
params[f"model.layers.{l}.post_attention_layernorm.weight"],
|
||||
f"model.layers.{l}.post_attention_layernorm.weight"
|
||||
)
|
||||
|
||||
# Final normalization and output head
|
||||
model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")
|
||||
|
||||
# Model uses weight tying, hence we reuse the embedding layer weights here
|
||||
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
|
||||
|
||||
|
||||
class Qwen3Tokenizer():
|
||||
def __init__(self, tokenizer_file_path="tokenizer.json",
|
||||
repo_id=None, add_generation_prompt=False, add_thinking=False):
|
||||
from tokenizers import Tokenizer
|
||||
self.tokenizer_file_path = tokenizer_file_path
|
||||
|
||||
if add_generation_prompt != add_thinking:
|
||||
raise ValueError(
|
||||
"Only add_generation_prompt==add_thinking settings are currently supported"
|
||||
)
|
||||
|
||||
self.add_generation_prompt = add_generation_prompt
|
||||
self.add_thinking = add_thinking
|
||||
|
||||
tokenizer_file_path_obj = Path(tokenizer_file_path)
|
||||
if not tokenizer_file_path_obj.is_file() and repo_id is not None:
|
||||
_ = download_from_huggingface(
|
||||
repo_id=repo_id,
|
||||
filename=str(tokenizer_file_path_obj.name),
|
||||
local_dir=str(tokenizer_file_path_obj.parent.name)
|
||||
)
|
||||
self.tokenizer = Tokenizer.from_file(tokenizer_file_path)
|
||||
|
||||
def encode(self, prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
formatted_prompt = self.format_qwen_chat(
|
||||
messages,
|
||||
add_generation_prompt=self.add_generation_prompt,
|
||||
add_thinking=self.add_thinking
|
||||
)
|
||||
return self.tokenizer.encode(formatted_prompt).ids
|
||||
|
||||
def decode(self, token_ids):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=False)
|
||||
|
||||
@staticmethod
|
||||
def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
|
||||
prompt = ""
|
||||
for msg in messages:
|
||||
prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
|
||||
if add_generation_prompt:
|
||||
prompt += "<|im_start|>assistant"
|
||||
if not add_thinking:
|
||||
prompt += "<|think>\n\n<|/think>\n\n"
|
||||
else:
|
||||
prompt += "\n"
|
||||
return prompt
|
||||
|
||||
|
||||
def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
|
||||
base_url = "https://huggingface.co"
|
||||
url = f"{base_url}/{repo_id}/resolve/{revision}/{filename}"
|
||||
Path(local_dir).mkdir(parents=True, exist_ok=True)
|
||||
dest_path = os.path.join(local_dir, filename)
|
||||
print(f"Downloading {url} to {dest_path}...")
|
||||
urllib.request.urlretrieve(url, dest_path)
|
||||
return dest_path
|
||||
|
pkg/llms_from_scratch/kv_cache/utils.py (new file)
@@ -0,0 +1,21 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

class KVCache:
    def __init__(self, n_layers):
        self.cache = [None] * n_layers

    def get(self, layer_idx):
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        self.cache[layer_idx] = value

    def get_all(self):
        return self.cache

    def reset(self):
        for i in range(len(self.cache)):
            self.cache[i] = None
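The cache itself is intentionally minimal: one `(keys, values)` tuple slot per transformer block, filled in by the model's `forward` when `use_cache=True` and a `cache` object is passed (see `generate_text_simple` earlier in this diff). A short standalone usage sketch:

```python
from llms_from_scratch.kv_cache.utils import KVCache  # import path as added in this PR

cache = KVCache(n_layers=2)
print(cache.get(0))             # None -- nothing stored for layer 0 yet

cache.update(0, ("k0", "v0"))   # the models store (keys, values) tensor tuples here
print(cache.get(0))             # ('k0', 'v0')
print(cache.get_all())          # [('k0', 'v0'), None]

cache.reset()
print(cache.get_all())          # [None, None]
```

Note that the models in this PR also track a `current_pos` counter of their own, which is cleared separately via `model.reset_kv_cache()`.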