diff --git a/.github/workflows/basic-tests-linux-uv.yml b/.github/workflows/basic-tests-linux-uv.yml
index ab651f0..d2b9cc4 100644
--- a/.github/workflows/basic-tests-linux-uv.yml
+++ b/.github/workflows/basic-tests-linux-uv.yml
@@ -49,6 +49,7 @@ jobs:
           source .venv/bin/activate
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
+          pytest --ruff ch04/03_kv-cache/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
           pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
diff --git a/ch04/03_kv-cache/README.md b/ch04/03_kv-cache/README.md
index 2284d14..5cb1d6c 100644
--- a/ch04/03_kv-cache/README.md
+++ b/ch04/03_kv-cache/README.md
@@ -86,6 +86,18 @@ def forward(self, x, use_cache=False):
         keys, values = self.cache_k, self.cache_v
     else:
         keys, values = keys_new, values_new
+
+    # ...
+
+    num_tokens_Q = queries.shape[-2]
+    num_tokens_K = keys.shape[-2]
+    if use_cache:
+        mask_bool = self.mask.bool()[
+            self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+        ]
+        self.ptr_current_pos += num_tokens_Q
+    else:
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 ```
 
@@ -98,6 +110,7 @@ When generating texts, between independent sequences (for instance to text gener
 ```python
 def reset_cache(self):
     self.cache_k, self.cache_v = None, None
+    self.ptr_current_pos = 0
 ```
 
@@ -157,30 +170,29 @@ def reset_kv_cache(self):
 With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
 
 ```python
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
 
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ```
diff --git a/ch04/03_kv-cache/gpt_with_kv_cache.py b/ch04/03_kv-cache/gpt_with_kv_cache.py
index c685fe0..760605b 100644
--- a/ch04/03_kv-cache/gpt_with_kv_cache.py
+++ b/ch04/03_kv-cache/gpt_with_kv_cache.py
@@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            torch.triu(torch.ones(context_length, context_length), diagonal=1),
             persistent=False
         )
 
@@ -35,6 +35,7 @@ class MultiHeadAttention(nn.Module):
         # NEW
         self.register_buffer("cache_k", None, persistent=False)
         self.register_buffer("cache_v", None, persistent=False)
+        self.ptr_current_pos = 0
         ####################################################
 
     def forward(self, x, use_cache=False):
@@ -71,8 +72,19 @@ class MultiHeadAttention(nn.Module):
         # Compute scaled dot-product attention (aka self-attention) with a causal mask
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
+        ####################################################
+        # NEW
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        if use_cache:
+            mask_bool = self.mask.bool()[
+                self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+            ]
+            self.ptr_current_pos += num_tokens_Q
+        ####################################################
         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        else:
+            mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 
         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
@@ -93,6 +105,7 @@ class MultiHeadAttention(nn.Module):
     # NEW
     def reset_cache(self):
         self.cache_k, self.cache_v = None, None
+        self.ptr_current_pos = 0
     ####################################################
 
 
@@ -264,30 +277,29 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
 
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
diff --git a/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py b/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
index e23df6c..745cac6 100644
--- a/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
+++ b/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
@@ -171,7 +171,8 @@ class TransformerBlock(nn.Module):
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
             qkv_bias=cfg["qkv_bias"],
-            window_size=cfg["kv_window_size"])  # NEW
+            window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]  # NEW
+        )
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
@@ -289,30 +290,25 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    ctx_len = context_size or model.pos_emb.num_embeddings
+
+    with torch.no_grad():
+        if use_cache:
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
diff --git a/ch04/03_kv-cache/tests.py b/ch04/03_kv-cache/tests.py
new file mode 100644
index 0000000..83aae44
--- /dev/null
+++ b/ch04/03_kv-cache/tests.py
@@ -0,0 +1,101 @@
+# Code to test the GPT model implementation against the KV cache variants
+
+import pytest
+import torch
+import tiktoken
+
+from gpt_ch04 import GPTModel as GPTModelBase
+from gpt_ch04 import generate_text_simple
+
+from gpt_with_kv_cache import GPTModel as GPTModelKV1
+from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
+from gpt_with_kv_cache import generate_text_simple_cached
+
+
+GPT_CONFIG_124M = {
+    "vocab_size": 50257,
+    "context_length": 1024,
+    "emb_dim": 768,
+    "n_heads": 12,
+    "n_layers": 12,
+    "drop_rate": 0.1,
+    "qkv_bias": False,
+}
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
+def test_gpt_model_equivalence_not_cached(ModelClass):
+    torch.manual_seed(123)
+
+    model = ModelClass(GPT_CONFIG_124M).to(device)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    prompt = "Hello, I am"
+    encoded = tokenizer.encode(prompt)
+    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+
+    model_name = ModelClass.__module__ + "." + ModelClass.__name__
+
+    token_ids = generate_text_simple(
+        model=model,
+        idx=encoded_tensor,
+        max_new_tokens=30,
+        context_size=GPT_CONFIG_124M["context_length"]
+    )
+
+    if not hasattr(test_gpt_model_equivalence_not_cached, "results"):
+        test_gpt_model_equivalence_not_cached.results = []
+
+    test_gpt_model_equivalence_not_cached.results.append((model_name, token_ids))
+
+    if len(test_gpt_model_equivalence_not_cached.results) == 3:
+        base_name, base_output = test_gpt_model_equivalence_not_cached.results[0]
+        for other_name, other_output in test_gpt_model_equivalence_not_cached.results[1:]:
+            assert torch.equal(base_output, other_output), (
+                f"Mismatch between {base_name} and {other_name}"
+            )
+
+
+@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
+def test_gpt_model_equivalence_cached(ModelClass):
+    torch.manual_seed(123)
+
+    model = ModelClass(GPT_CONFIG_124M).to(device)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    prompt = "Hello, I am"
+    encoded_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
+
+    model_name = ModelClass.__module__ + "." + ModelClass.__name__
+
+    if ModelClass is GPTModelBase:
+        token_ids = generate_text_simple(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
+    else:
+        token_ids = generate_text_simple_cached(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
+
+    if not hasattr(test_gpt_model_equivalence_cached, "results"):
+        test_gpt_model_equivalence_cached.results = []
+
+    test_gpt_model_equivalence_cached.results.append((model_name, token_ids))
+
+    if len(test_gpt_model_equivalence_cached.results) == 3:
+        base_name, base_output = test_gpt_model_equivalence_cached.results[0]
+        for other_name, other_output in test_gpt_model_equivalence_cached.results[1:]:
+            assert torch.equal(base_output, other_output), (
+                f"Mismatch between {base_name} and {other_name}"
+            )
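
For quick reference, below is a minimal usage sketch of the updated `generate_text_simple_cached` signature introduced in this patch. It mirrors the setup in `tests.py` (GPT-2 124M config, tiktoken GPT-2 tokenizer, `"Hello, I am"` prompt) and assumes it is run from `ch04/03_kv-cache/` so the patched modules are importable; `context_size` is optional and falls back to `model.pos_emb.num_embeddings` when omitted. This is an illustrative sketch, not part of the patch itself.

```python
# Illustrative usage of the new context_size argument of generate_text_simple_cached.
# Assumes the working directory is ch04/03_kv-cache/.
import torch
import tiktoken

from gpt_with_kv_cache import GPTModel, generate_text_simple_cached

GPT_CONFIG_124M = {
    "vocab_size": 50257,
    "context_length": 1024,
    "emb_dim": 768,
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.1,
    "qkv_bias": False,
}

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)

tokenizer = tiktoken.get_encoding("gpt2")
idx = torch.tensor(tokenizer.encode("Hello, I am")).unsqueeze(0)

# Uses the KV cache internally; pass use_cache=False to compare against
# the uncached code path.
token_ids = generate_text_simple_cached(
    model=model,
    idx=idx,
    max_new_tokens=30,
    context_size=GPT_CONFIG_124M["context_length"],
)
print(tokenizer.decode(token_ids.squeeze(0).tolist()))
```

Locally, the new test file can be run the same way the added CI step does: `pytest --ruff ch04/03_kv-cache/tests.py`.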