Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-11-12 16:15:22 +00:00)
Simplify KV cache usage (#728)
* Simplify KV cache usage
* Swap MarkText with Ghostwriter
This commit is contained in:
parent b5bd8d2de2
commit 90c824506c
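The change drops the separate `use_cache` boolean from the model, block, and attention `forward()` signatures; passing a `cache` object is now what turns caching on. Below is a minimal sketch of the call-site difference, assuming a model instance with the `cfg`, `reset_kv_cache()`, and `KVCache` usage shown in the diffs that follow ("model" is a hypothetical `Llama3Model`/`Qwen3Model` instance, not code from this commit):

# Before this commit, caching required both a flag and a cache object:
#     logits = model(idx, use_cache=True, cache=cache)
#
# After this commit, passing a cache object is all that is needed:
cache = KVCache(n_layers=model.cfg["n_layers"])
model.reset_kv_cache()
logits = model(idx, cache=cache)  # cached decoding; omit cache (or pass None) for the uncached path
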
@@ -40,7 +40,7 @@ git clone --depth 1 https://github.com/rasbt/LLMs-from-scratch.git

 # Table of Contents

-Please note that this `README.md` file is a Markdown (`.md`) file. If you have downloaded this code bundle from the Manning website and are viewing it on your local computer, I recommend using a Markdown editor or previewer for proper viewing. If you haven't installed a Markdown editor yet, [MarkText](https://www.marktext.cc) is a good free option.
+Please note that this `README.md` file is a Markdown (`.md`) file. If you have downloaded this code bundle from the Manning website and are viewing it on your local computer, I recommend using a Markdown editor or previewer for proper viewing. If you haven't installed a Markdown editor yet, [Ghostwriter](https://ghostwriter.kde.org) is a good free option.

 You can alternatively view this and other files on GitHub at [https://github.com/rasbt/LLMs-from-scratch](https://github.com/rasbt/LLMs-from-scratch) in your browser, which renders Markdown automatically.

@@ -10,20 +10,20 @@ import torch

 def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
     ctx_len = context_size or model.cfg["context_length"]
-    cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None

     with torch.no_grad():
         if use_cache:
+            cache = KVCache(n_layers=model.cfg["n_layers"])
             model.reset_kv_cache()
-            logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache)
+            logits = model(idx[:, -ctx_len:], cache=cache)

             for _ in range(max_new_tokens):
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
-                logits = model(next_idx, use_cache=True, cache=cache)
+                logits = model(next_idx, cache=cache)
         else:
             for _ in range(max_new_tokens):
-                logits = model(idx[:, -ctx_len:], use_cache=False)
+                logits = model(idx[:, -ctx_len:], cache=None)
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)

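The diff shows how `KVCache` is used (`KVCache(n_layers=...)`, `cache.get(i)`, `cache.update(i, new_blk_cache)`) but not its definition. Below is a minimal sketch of a cache object that satisfies those calls; the class actually shipped in the repository may differ in details:

class KVCache:
    # Holds one (keys, values) tensor pair per transformer block.
    def __init__(self, n_layers):
        self.cache = [None] * n_layers

    def get(self, layer_idx):
        # Returns the cached (keys, values) pair for a block, or None on the first step
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        # Stores the (keys, values) pair returned by the block's attention module
        self.cache[layer_idx] = value

    def reset(self):
        # Clears all per-layer entries so a new prompt starts from an empty cache
        self.cache = [None] * len(self.cache)
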
@@ -77,12 +77,12 @@ class Llama3Model(nn.Module):
         self.cfg = cfg
         self.current_pos = 0  # Track current position in KV cache

-    def forward(self, in_idx, use_cache=False, cache=None):
+    def forward(self, in_idx, cache=None):
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds

         num_tokens = x.shape[1]
-        if use_cache:
+        if cache is not None:
             pos_start = self.current_pos
             pos_end = pos_start + num_tokens
             self.current_pos = pos_end
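`model.reset_kv_cache()` (called in `generate_text_simple` above) is not part of this diff. Since the simplified `generate_text_simple` creates a fresh `KVCache` on every call, the method plausibly only needs to rewind the position counter that drives the RoPE offsets; a sketch under that assumption, not the repository's actual implementation:

    def reset_kv_cache(self):
        # Assumed behavior: restart position tracking so RoPE offsets begin at 0
        # for the next prompt; the cache object itself is recreated by the caller.
        self.current_pos = 0
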
@@ -101,10 +101,9 @@ class Llama3Model(nn.Module):
         for i, block in enumerate(self.trf_blocks):
             blk_cache = cache.get(i) if cache else None
             x, new_blk_cache = block(x, mask, self.cos, self.sin,
-                                     use_cache=use_cache,
                                      start_pos=pos_start,
                                      cache=blk_cache)
-            if cache:
+            if cache is not None:
                 cache.update(i, new_blk_cache)
             next_cache.append(new_blk_cache)

@@ -130,11 +129,11 @@ class TransformerBlock(nn.Module):
         self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
         self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])

-    def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
+    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
         # Shortcut connection for attention block
         shortcut = x
         x = self.norm1(x)
-        x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
+        x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
         x = x + shortcut  # Add the original input back

         # Shortcut connection for feed-forward block
@@ -180,7 +179,7 @@ class GroupedQueryAttention(nn.Module):
         self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
         self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

-    def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
+    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
         b, num_tokens, _ = x.shape

         # Apply projections
@@ -197,18 +196,15 @@ class GroupedQueryAttention(nn.Module):
         queries = apply_rope(queries, cos, sin, offset=start_pos)
         keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)

-        if use_cache:
-            if cache is None:
-                keys = keys_new
-                values = values_new
-            else:
-                prev_k, prev_v = cache
-                keys = torch.cat([prev_k, keys_new], dim=2)
-                values = torch.cat([prev_v, values_new], dim=2)
+        if cache is not None:
+            prev_k, prev_v = cache
+            keys = torch.cat([prev_k, keys_new], dim=2)
+            values = torch.cat([prev_v, values_new], dim=2)
             next_cache = (keys, values)
         else:
+            start_pos = 0  # reset RoPE
             keys, values = keys_new, values_new
-            next_cache = None
+            next_cache = (keys, values)

         # Expand keys and values to match the number of heads
         # Shape: (b, num_heads, num_tokens, head_dim)
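To make the cached branch concrete, here is a small standalone example of how concatenating along `dim=2` grows the cached keys by one position per decoding step (the shapes are illustrative only, not taken from a specific model configuration):

import torch

b, n_kv_heads, head_dim = 1, 4, 64

# Prompt step: nothing cached yet, so the cache starts from the prompt's keys/values
keys = torch.randn(b, n_kv_heads, 5, head_dim)
cache = (keys, keys.clone())  # (prev_k, prev_v)

# Decode step: a single new token produces keys_new of shape (b, heads, 1, head_dim)
keys_new = torch.randn(b, n_kv_heads, 1, head_dim)
prev_k, prev_v = cache
keys = torch.cat([prev_k, keys_new], dim=2)
print(keys.shape)  # torch.Size([1, 4, 6, 64]) -- the sequence axis grows by one per step
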
@@ -226,7 +222,7 @@ class GroupedQueryAttention(nn.Module):
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

         # Use the mask to fill attention scores
-        attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
+        attn_scores = attn_scores.masked_fill(mask, -torch.inf)

         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
         assert keys.shape[-1] == self.head_dim
@@ -286,7 +282,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_c
     return cos, sin


-def apply_rope(x, cos, sin, offset=9):
+def apply_rope(x, cos, sin, offset=0):
     # x: (batch_size, num_heads, seq_len, head_dim)
     batch_size, num_heads, seq_len, head_dim = x.shape
     assert head_dim % 2 == 0, "Head dimension must be even"

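The `offset` parameter exists so that, during cached decoding, a newly generated token is rotated at its absolute position (`start_pos`) rather than at position 0. The body of `apply_rope` is not part of this diff; the following is a self-contained sketch of an offset-aware RoPE application using the common rotate-half formulation, not necessarily the repository's exact code:

import torch

def apply_rope_sketch(x, cos, sin, offset=0):
    # x: (batch_size, num_heads, seq_len, head_dim); cos/sin: (context_length, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    x1 = x[..., : head_dim // 2]
    x2 = x[..., head_dim // 2:]
    # Select the rows matching the absolute positions of these tokens
    cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    rotated = torch.cat((-x2, x1), dim=-1)
    return (x * cos) + (rotated * sin)
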
@@ -44,13 +44,13 @@ class Qwen3Model(nn.Module):
         self.cfg = cfg
         self.current_pos = 0  # Track current position in KV cache

-    def forward(self, in_idx, use_cache=False, cache=None):
+    def forward(self, in_idx, cache=None):
         # Forward pass
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds

         num_tokens = x.shape[1]
-        if use_cache:
+        if cache is not None:
             pos_start = self.current_pos
             pos_end = pos_start + num_tokens
             self.current_pos = pos_end
@@ -69,10 +69,9 @@ class Qwen3Model(nn.Module):
         for i, block in enumerate(self.trf_blocks):
             blk_cache = cache.get(i) if cache else None
             x, new_blk_cache = block(x, mask, self.cos, self.sin,
-                                     use_cache=use_cache,
                                      start_pos=pos_start,
                                      cache=blk_cache)
-            if cache:
+            if cache is not None:
                 cache.update(i, new_blk_cache)
             next_cache.append(new_blk_cache)

@@ -99,11 +98,11 @@ class TransformerBlock(nn.Module):
         self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
         self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)

-    def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
+    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
         # Shortcut connection for attention block
         shortcut = x
         x = self.norm1(x)
-        x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
+        x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
         x = x + shortcut  # Add the original input back

         # Shortcut connection for feed-forward block
@@ -159,7 +158,7 @@ class GroupedQueryAttention(nn.Module):
         else:
             self.q_norm = self.k_norm = None

-    def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
+    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
         b, num_tokens, _ = x.shape

         # Apply projections
@@ -182,18 +181,15 @@ class GroupedQueryAttention(nn.Module):
         queries = apply_rope(queries, cos, sin, offset=start_pos)
         keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)

-        if use_cache:
-            if cache is None:
-                keys = keys_new
-                values = values_new
-            else:
-                prev_k, prev_v = cache
-                keys = torch.cat([prev_k, keys_new], dim=2)
-                values = torch.cat([prev_v, values_new], dim=2)
+        if cache is not None:
+            prev_k, prev_v = cache
+            keys = torch.cat([prev_k, keys_new], dim=2)
+            values = torch.cat([prev_v, values_new], dim=2)
             next_cache = (keys, values)
         else:
+            start_pos = 0  # reset RoPE
             keys, values = keys_new, values_new
-            next_cache = None
+            next_cache = (keys, values)

         # Expand K and V to match number of heads
         keys = keys.repeat_interleave(self.group_size, dim=1)