diff --git a/README.md b/README.md index 877727c..3ba5771 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ git clone --depth 1 https://github.com/rasbt/LLMs-from-scratch.git # Table of Contents -Please note that this `README.md` file is a Markdown (`.md`) file. If you have downloaded this code bundle from the Manning website and are viewing it on your local computer, I recommend using a Markdown editor or previewer for proper viewing. If you haven't installed a Markdown editor yet, [MarkText](https://www.marktext.cc) is a good free option. +Please note that this `README.md` file is a Markdown (`.md`) file. If you have downloaded this code bundle from the Manning website and are viewing it on your local computer, I recommend using a Markdown editor or previewer for proper viewing. If you haven't installed a Markdown editor yet, [Ghostwriter](https://ghostwriter.kde.org) is a good free option. You can alternatively view this and other files on GitHub at [https://github.com/rasbt/LLMs-from-scratch](https://github.com/rasbt/LLMs-from-scratch) in your browser, which renders Markdown automatically. diff --git a/pkg/llms_from_scratch/kv_cache/generate.py b/pkg/llms_from_scratch/kv_cache/generate.py index 03e1282..12ee0c6 100644 --- a/pkg/llms_from_scratch/kv_cache/generate.py +++ b/pkg/llms_from_scratch/kv_cache/generate.py @@ -10,20 +10,20 @@ import torch def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True): model.eval() ctx_len = context_size or model.cfg["context_length"] - cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None with torch.no_grad(): if use_cache: + cache = KVCache(n_layers=model.cfg["n_layers"]) model.reset_kv_cache() - logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache) + logits = model(idx[:, -ctx_len:], cache=cache) for _ in range(max_new_tokens): next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) idx = torch.cat([idx, next_idx], dim=1) - logits = model(next_idx, use_cache=True, cache=cache) + logits = model(next_idx, cache=cache) else: for _ in range(max_new_tokens): - logits = model(idx[:, -ctx_len:], use_cache=False) + logits = model(idx[:, -ctx_len:], cache=None) next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) idx = torch.cat([idx, next_idx], dim=1) diff --git a/pkg/llms_from_scratch/kv_cache/llama3.py b/pkg/llms_from_scratch/kv_cache/llama3.py index cafb785..70258d0 100644 --- a/pkg/llms_from_scratch/kv_cache/llama3.py +++ b/pkg/llms_from_scratch/kv_cache/llama3.py @@ -77,12 +77,12 @@ class Llama3Model(nn.Module): self.cfg = cfg self.current_pos = 0 # Track current position in KV cache - def forward(self, in_idx, use_cache=False, cache=None): + def forward(self, in_idx, cache=None): tok_embeds = self.tok_emb(in_idx) x = tok_embeds num_tokens = x.shape[1] - if use_cache: + if cache is not None: pos_start = self.current_pos pos_end = pos_start + num_tokens self.current_pos = pos_end @@ -101,10 +101,9 @@ class Llama3Model(nn.Module): for i, block in enumerate(self.trf_blocks): blk_cache = cache.get(i) if cache else None x, new_blk_cache = block(x, mask, self.cos, self.sin, - use_cache=use_cache, start_pos=pos_start, cache=blk_cache) - if cache: + if cache is not None: cache.update(i, new_blk_cache) next_cache.append(new_blk_cache) @@ -130,11 +129,11 @@ class TransformerBlock(nn.Module): self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) - def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None): + def forward(self, x, mask, cos, sin, start_pos=0, cache=None): # Shortcut connection for attention block shortcut = x x = self.norm1(x) - x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size] + x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size] x = x + shortcut # Add the original input back # Shortcut connection for feed-forward block @@ -180,7 +179,7 @@ class GroupedQueryAttention(nn.Module): self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype) self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) - def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None): + def forward(self, x, mask, cos, sin, start_pos=0, cache=None): b, num_tokens, _ = x.shape # Apply projections @@ -197,18 +196,15 @@ class GroupedQueryAttention(nn.Module): queries = apply_rope(queries, cos, sin, offset=start_pos) keys_new = apply_rope(keys_new, cos, sin, offset=start_pos) - if use_cache: - if cache is None: - keys = keys_new - values = values_new - else: - prev_k, prev_v = cache - keys = torch.cat([prev_k, keys_new], dim=2) - values = torch.cat([prev_v, values_new], dim=2) + if cache is not None: + prev_k, prev_v = cache + keys = torch.cat([prev_k, keys_new], dim=2) + values = torch.cat([prev_v, values_new], dim=2) next_cache = (keys, values) else: + start_pos = 0 # reset RoPE keys, values = keys_new, values_new - next_cache = None + next_cache = (keys, values) # Expand keys and values to match the number of heads # Shape: (b, num_heads, num_tokens, head_dim) @@ -226,7 +222,7 @@ class GroupedQueryAttention(nn.Module): attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head # Use the mask to fill attention scores - attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf) + attn_scores = attn_scores.masked_fill(mask, -torch.inf) attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) assert keys.shape[-1] == self.head_dim @@ -286,7 +282,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_c return cos, sin -def apply_rope(x, cos, sin, offset=9): +def apply_rope(x, cos, sin, offset=0): # x: (batch_size, num_heads, seq_len, head_dim) batch_size, num_heads, seq_len, head_dim = x.shape assert head_dim % 2 == 0, "Head dimension must be even" diff --git a/pkg/llms_from_scratch/kv_cache/qwen3.py b/pkg/llms_from_scratch/kv_cache/qwen3.py index 7f7ff8d..cb60112 100644 --- a/pkg/llms_from_scratch/kv_cache/qwen3.py +++ b/pkg/llms_from_scratch/kv_cache/qwen3.py @@ -44,13 +44,13 @@ class Qwen3Model(nn.Module): self.cfg = cfg self.current_pos = 0 # Track current position in KV cache - def forward(self, in_idx, use_cache=False, cache=None): + def forward(self, in_idx, cache=None): # Forward pass tok_embeds = self.tok_emb(in_idx) x = tok_embeds num_tokens = x.shape[1] - if use_cache: + if cache is not None: pos_start = self.current_pos pos_end = pos_start + num_tokens self.current_pos = pos_end @@ -69,10 +69,9 @@ class Qwen3Model(nn.Module): for i, block in enumerate(self.trf_blocks): blk_cache = cache.get(i) if cache else None x, new_blk_cache = block(x, mask, self.cos, self.sin, - use_cache=use_cache, start_pos=pos_start, cache=blk_cache) - if cache: + if cache is not None: cache.update(i, new_blk_cache) next_cache.append(new_blk_cache) @@ -99,11 +98,11 @@ class TransformerBlock(nn.Module): self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6) self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6) - def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None): + def forward(self, x, mask, cos, sin, start_pos=0, cache=None): # Shortcut connection for attention block shortcut = x x = self.norm1(x) - x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size] + x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size] x = x + shortcut # Add the original input back # Shortcut connection for feed-forward block @@ -159,7 +158,7 @@ class GroupedQueryAttention(nn.Module): else: self.q_norm = self.k_norm = None - def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None): + def forward(self, x, mask, cos, sin, start_pos=0, cache=None): b, num_tokens, _ = x.shape # Apply projections @@ -182,18 +181,15 @@ class GroupedQueryAttention(nn.Module): queries = apply_rope(queries, cos, sin, offset=start_pos) keys_new = apply_rope(keys_new, cos, sin, offset=start_pos) - if use_cache: - if cache is None: - keys = keys_new - values = values_new - else: - prev_k, prev_v = cache - keys = torch.cat([prev_k, keys_new], dim=2) - values = torch.cat([prev_v, values_new], dim=2) + if cache is not None: + prev_k, prev_v = cache + keys = torch.cat([prev_k, keys_new], dim=2) + values = torch.cat([prev_v, values_new], dim=2) next_cache = (keys, values) else: + start_pos = 0 # reset RoPE keys, values = keys_new, values_new - next_cache = None + next_cache = (keys, values) # Expand K and V to match number of heads keys = keys.repeat_interleave(self.group_size, dim=1)