mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-27 15:59:49 +00:00
Add GPT-2 KV cache to pkg (#687)
This commit is contained in:
parent
3be0f3202a
commit
fdc3e1b701
@ -80,8 +80,6 @@ class MultiHeadAttention(nn.Module):
|
||||
keys, values = keys_new, values_new
|
||||
self.ptr_cur = 0 # keep pointer sane if you interleave modes
|
||||
####################################################
|
||||
|
||||
|
||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
|
||||
|
||||
@ -113,7 +113,22 @@ from llms_from_scratch.appendix_d import find_highest_gradient, train_model
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
### GPT-2 KV cache variant (Bonus material)
|
||||
|
||||
```python
|
||||
from llms_from_scratch.kv_cache.gpt2 import GPTModel
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple
|
||||
```
|
||||
|
||||
For more information about KV caching, please see the [KV cache README](../../ch04/03_kv-cache).
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
### Llama 3 (Bonus material)
|
||||
|
||||
```python
|
||||
|
||||
287
pkg/llms_from_scratch/kv_cache/gpt2.py
Normal file
287
pkg/llms_from_scratch/kv_cache/gpt2.py
Normal file
@ -0,0 +1,287 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
#####################################
|
||||
# Chapter 3
|
||||
#####################################
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
|
||||
|
||||
self.d_out = d_out
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
|
||||
|
||||
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
self.max_seq_len = max_seq_len or context_length
|
||||
self.window_size = window_size or self.max_seq_len
|
||||
self.register_buffer("cache_k", None, persistent=False)
|
||||
self.register_buffer("cache_v", None, persistent=False)
|
||||
####################################################
|
||||
|
||||
def forward(self, x, use_cache=False):
|
||||
b, num_tokens, d_in = x.shape
|
||||
|
||||
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
|
||||
values_new = self.W_value(x)
|
||||
queries = self.W_query(x)
|
||||
|
||||
# We implicitly split the matrix by adding a `num_heads` dimension
|
||||
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
|
||||
keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
|
||||
|
||||
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
|
||||
keys_new = keys_new.transpose(1, 2)
|
||||
values_new = values_new.transpose(1, 2)
|
||||
queries = queries.transpose(1, 2)
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
if use_cache:
|
||||
if self.cache_k is None or self.cache_k.size(0) != b:
|
||||
self.cache_k = torch.zeros(b, self.num_heads,
|
||||
self.window_size, self.head_dim,
|
||||
device=x.device)
|
||||
self.cache_v = torch.zeros_like(self.cache_k)
|
||||
self.ptr_cur = 0 # pointer to next free slot
|
||||
|
||||
# if incoming chunk would overflow discard oldest tokens
|
||||
if self.ptr_cur + num_tokens > self.window_size:
|
||||
overflow = self.ptr_cur + num_tokens - self.window_size
|
||||
# shift everything left by `overflow` (cheap view-copy)
|
||||
self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
|
||||
self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
|
||||
self.ptr_cur -= overflow # pointer after shift
|
||||
|
||||
self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
|
||||
self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
|
||||
self.ptr_cur += num_tokens
|
||||
|
||||
keys = self.cache_k[:, :, :self.ptr_cur, :]
|
||||
values = self.cache_v[:, :, :self.ptr_cur, :]
|
||||
else:
|
||||
keys, values = keys_new, values_new
|
||||
self.ptr_cur = 0 # keep pointer sane if you interleave modes
|
||||
####################################################
|
||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
K = attn_scores.size(-1)
|
||||
|
||||
if num_tokens == K:
|
||||
# No cache → use the pre‑baked triangular mask slice
|
||||
causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
|
||||
else:
|
||||
# Cached: need to offset the diagonal by (K − num_tokens)
|
||||
offset = K - num_tokens # number of tokens already in cache before this chunk
|
||||
row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1) # (num_tokens, 1)
|
||||
col_idx = torch.arange(K, device=x.device).unsqueeze(0) # (1, K)
|
||||
causal_mask = row_idx + offset < col_idx # True where j > i+offset
|
||||
####################################################
|
||||
|
||||
# Use the mask to fill attention scores
|
||||
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
|
||||
|
||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||
attn_weights = self.dropout(attn_weights)
|
||||
|
||||
# Shape: (b, num_tokens, num_heads, head_dim)
|
||||
context_vec = (attn_weights @ values).transpose(1, 2)
|
||||
|
||||
# Combine heads, where self.d_out = self.num_heads * self.head_dim
|
||||
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
|
||||
context_vec = self.out_proj(context_vec) # optional projection
|
||||
|
||||
return context_vec
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
def reset_cache(self):
|
||||
self.cache_k, self.cache_v = None, None
|
||||
####################################################
|
||||
|
||||
|
||||
#####################################
|
||||
# Chapter 4
|
||||
#####################################
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, emb_dim):
|
||||
super().__init__()
|
||||
self.eps = 1e-5
|
||||
self.scale = nn.Parameter(torch.ones(emb_dim))
|
||||
self.shift = nn.Parameter(torch.zeros(emb_dim))
|
||||
|
||||
def forward(self, x):
|
||||
mean = x.mean(dim=-1, keepdim=True)
|
||||
var = x.var(dim=-1, keepdim=True, unbiased=False)
|
||||
norm_x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
return self.scale * norm_x + self.shift
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return 0.5 * x * (1 + torch.tanh(
|
||||
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
|
||||
(x + 0.044715 * torch.pow(x, 3))
|
||||
))
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
|
||||
GELU(),
|
||||
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.att = MultiHeadAttention(
|
||||
d_in=cfg["emb_dim"],
|
||||
d_out=cfg["emb_dim"],
|
||||
context_length=cfg["context_length"],
|
||||
num_heads=cfg["n_heads"],
|
||||
dropout=cfg["drop_rate"],
|
||||
qkv_bias=cfg["qkv_bias"],
|
||||
window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"] # NEW
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = LayerNorm(cfg["emb_dim"])
|
||||
self.norm2 = LayerNorm(cfg["emb_dim"])
|
||||
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
|
||||
|
||||
def forward(self, x, use_cache=False):
|
||||
# Shortcut connection for attention block
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
|
||||
# x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
|
||||
####################################################
|
||||
# NEW
|
||||
x = self.att(x, use_cache=use_cache)
|
||||
####################################################
|
||||
|
||||
x = self.drop_shortcut(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
# Shortcut connection for feed-forward block
|
||||
shortcut = x
|
||||
x = self.norm2(x)
|
||||
x = self.ff(x)
|
||||
x = self.drop_shortcut(x)
|
||||
x = x + shortcut # Add the original input back
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class GPTModel(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
|
||||
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
|
||||
self.drop_emb = nn.Dropout(cfg["drop_rate"])
|
||||
|
||||
# self.trf_blocks = nn.Sequential(
|
||||
# *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
####################################################
|
||||
# NEW
|
||||
self.trf_blocks = nn.ModuleList(
|
||||
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
|
||||
self.ptr_current_pos = 0
|
||||
####################################################
|
||||
|
||||
self.final_norm = LayerNorm(cfg["emb_dim"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
|
||||
|
||||
def forward(self, in_idx, use_cache=False):
|
||||
batch_size, seq_len = in_idx.shape
|
||||
tok_embeds = self.tok_emb(in_idx)
|
||||
|
||||
# pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
|
||||
if use_cache:
|
||||
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
|
||||
self.ptr_current_pos += seq_len
|
||||
else:
|
||||
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
|
||||
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
|
||||
####################################################
|
||||
|
||||
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
|
||||
x = self.drop_emb(x)
|
||||
|
||||
# x = self.trf_blocks(x)
|
||||
####################################################
|
||||
# NEW
|
||||
for blk in self.trf_blocks:
|
||||
x = blk(x, use_cache=use_cache)
|
||||
####################################################
|
||||
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x)
|
||||
return logits
|
||||
|
||||
####################################################
|
||||
# NEW
|
||||
def reset_kv_cache(self):
|
||||
for blk in self.trf_blocks:
|
||||
blk.att.reset_cache()
|
||||
self.ptr_current_pos = 0
|
||||
####################################################
|
||||
|
||||
|
||||
def generate_text_simple(model, idx, max_new_tokens, context_size):
|
||||
# idx is (B, T) array of indices in the current context
|
||||
for _ in range(max_new_tokens):
|
||||
|
||||
# Crop current context if it exceeds the supported context size
|
||||
# E.g., if LLM supports only 5 tokens, and the context size is 10
|
||||
# then only the last 5 tokens are used as context
|
||||
idx_cond = idx[:, -context_size:]
|
||||
|
||||
# Get the predictions
|
||||
with torch.no_grad():
|
||||
logits = model(idx_cond)
|
||||
|
||||
# Focus only on the last time step
|
||||
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
# Get the idx of the vocab entry with the highest logits value
|
||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
|
||||
|
||||
# Append sampled index to the running sequence
|
||||
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
|
||||
|
||||
return idx
|
||||
@ -4,7 +4,9 @@
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from llms_from_scratch.ch04 import GPTModel, GPTModelFast
|
||||
from llms_from_scratch.kv_cache.gpt2 import GPTModel as GPTModelKV
|
||||
from llms_from_scratch.ch04 import generate_text_simple
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -22,8 +24,16 @@ GPT_CONFIG_124M = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
|
||||
def test_gpt_model_variants(ModelClass):
|
||||
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast, GPTModelKV])
|
||||
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
|
||||
def test_gpt_model_variants(ModelClass, generate_fn):
|
||||
|
||||
# Skip incompatible combinations
|
||||
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(GPT_CONFIG_124M)
|
||||
model.eval() # disable dropout
|
||||
@ -39,7 +49,7 @@ def test_gpt_model_variants(ModelClass):
|
||||
print("Encoded input text:", encoded)
|
||||
print("encoded_tensor.shape:", encoded_tensor.shape)
|
||||
|
||||
out = generate_text_simple(
|
||||
out = generate_fn(
|
||||
model=model,
|
||||
idx=encoded_tensor,
|
||||
max_new_tokens=10,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user