Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-11-04 03:40:21 +00:00)
Fix bug in masking when kv cache is used. (#697)

* Fix bug in masking when kv cache is used
* Add tests
* Add kv cache test to GitHub Actions workflow
* Explicit mask slicing

Co-authored-by: rasbt <mail@sebastianraschka.com>
This commit is contained in:
parent e9ffdbace4
commit ad16b1fbee
.github/workflows/basic-tests-linux-uv.yml (vendored): 1 line changed
@@ -49,6 +49,7 @@ jobs:
           source .venv/bin/activate
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
+          pytest --ruff ch04/03_kv-cache/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
           pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
@@ -86,6 +86,18 @@ def forward(self, x, use_cache=False):
         keys, values = self.cache_k, self.cache_v
     else:
         keys, values = keys_new, values_new
 
     # ...
 
+    num_tokens_Q = queries.shape[-2]
+    num_tokens_K = keys.shape[-2]
+    if use_cache:
+        mask_bool = self.mask.bool()[
+            self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+        ]
+        self.ptr_current_pos += num_tokens_Q
+    else:
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 ```
 
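For intuition, here is a small self-contained sketch (not part of this commit) that checks the property the new slicing relies on: the mask rows selected via `ptr_current_pos` for the newly decoded tokens must match the last rows of the full causal mask that a non-cached forward pass over the same sequence would use.

```python
# Minimal sketch (not part of the diff): with a KV cache, only the query rows
# for the *new* tokens are needed, and they must be taken at the current cache
# offset rather than from row 0 of the causal mask.
import torch

context_length = 8
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

ptr_current_pos = 0  # mimics self.ptr_current_pos

# A 3-token prompt followed by single-token decode steps
for num_tokens_q in [3, 1, 1, 1, 1, 1]:
    num_tokens_k = ptr_current_pos + num_tokens_q  # keys grow with the cache

    # Sliced mask used in the cached path
    mask_bool = mask.bool()[
        ptr_current_pos:ptr_current_pos + num_tokens_q, :num_tokens_k
    ]

    # Mask a non-cached forward pass over the full sequence would use;
    # the cached slice must equal its last num_tokens_q rows
    full_mask = mask.bool()[:num_tokens_k, :num_tokens_k]
    assert torch.equal(mask_bool, full_mask[-num_tokens_q:, :])

    ptr_current_pos += num_tokens_q

print("Cached mask slices match the non-cached causal mask")
```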
@@ -98,6 +110,7 @@ When generating texts, between independent sequences (for instance to text gener
 ```python
 def reset_cache(self):
     self.cache_k, self.cache_v = None, None
+    self.ptr_current_pos = 0
 ```
 
@@ -157,30 +170,29 @@ def reset_kv_cache(self):
 With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
 
 ```python
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ```
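A usage sketch mirroring the new tests (it assumes `gpt_with_kv_cache.py` from this folder is importable and `tiktoken` is installed; the weights are randomly initialized here, so the decoded text is not meaningful):

```python
# Usage sketch mirroring tests.py below; not part of the diff itself.
import torch
import tiktoken

from gpt_with_kv_cache import GPTModel, generate_text_simple_cached

GPT_CONFIG_124M = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
    "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False,
}

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.eval()

tokenizer = tiktoken.get_encoding("gpt2")
idx = torch.tensor(tokenizer.encode("Hello, I am")).unsqueeze(0)

token_ids = generate_text_simple_cached(
    model=model,
    idx=idx,
    max_new_tokens=30,
    context_size=GPT_CONFIG_124M["context_length"],  # optional; defaults to model.pos_emb.num_embeddings
)
print(tokenizer.decode(token_ids.squeeze(0).tolist()))
```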
@@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            torch.triu(torch.ones(context_length, context_length), diagonal=1),
             persistent=False
         )
 
@@ -35,6 +35,7 @@ class MultiHeadAttention(nn.Module):
         # NEW
         self.register_buffer("cache_k", None, persistent=False)
         self.register_buffer("cache_v", None, persistent=False)
+        self.ptr_current_pos = 0
         ####################################################
 
     def forward(self, x, use_cache=False):
@@ -71,8 +72,19 @@ class MultiHeadAttention(nn.Module):
         # Compute scaled dot-product attention (aka self-attention) with a causal mask
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
-        # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        ####################################################
+        # NEW
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        if use_cache:
+            mask_bool = self.mask.bool()[
+                self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+            ]
+            self.ptr_current_pos += num_tokens_Q
+        else:
+            mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
+        ####################################################
 
         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
@@ -93,6 +105,7 @@ class MultiHeadAttention(nn.Module):
     # NEW
     def reset_cache(self):
         self.cache_k, self.cache_v = None, None
+        self.ptr_current_pos = 0
     ####################################################
 
 
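For context, a brief sketch (not part of the diff) of why the pointer reset matters when calling the model with `use_cache=True` directly rather than through `generate_text_simple_cached`. It assumes a `GPTModel` built from `gpt_with_kv_cache.py` and uses a deliberately tiny config for illustration:

```python
# Sketch: reset the cache between independent sequences, otherwise
# cache_k/cache_v and ptr_current_pos carry over and the mask slicing
# starts at a stale offset for the second prompt.
import torch
from gpt_with_kv_cache import GPTModel

cfg = {
    "vocab_size": 100, "context_length": 32, "emb_dim": 16,
    "n_heads": 2, "n_layers": 1, "drop_rate": 0.0, "qkv_bias": False,
}
model = GPTModel(cfg)
model.eval()

prompt_a = torch.randint(0, cfg["vocab_size"], (1, 5))
prompt_b = torch.randint(0, cfg["vocab_size"], (1, 7))

with torch.no_grad():
    logits_a = model(prompt_a, use_cache=True)  # fills the cache, advances ptr_current_pos

    model.reset_kv_cache()                      # clears cache_k/cache_v, resets ptr_current_pos
    logits_b = model(prompt_b, use_cache=True)  # second sequence starts again at position 0
```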
@@ -264,30 +277,29 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
 
@@ -171,7 +171,8 @@ class TransformerBlock(nn.Module):
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
             qkv_bias=cfg["qkv_bias"],
-            window_size=cfg["kv_window_size"])  # NEW
+            window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]   # NEW
+        )
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
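This fallback makes `kv_window_size` optional, which matters because the shared config in the new `tests.py` does not define it. A minimal sketch of the resulting behavior; the `dict.get()` form is noted only as an equivalent alternative, not what the repo uses:

```python
# Sketch: configs without "kv_window_size" now work, with the KV window
# defaulting to the full context length.
cfg = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
    "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False,
    # no "kv_window_size" entry
}

# Same logic as in the diff; dict.get() would be an equivalent alternative
window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]
assert window_size == cfg.get("kv_window_size", cfg["context_length"]) == 1024
```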
@@ -289,30 +290,25 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    ctx_len = context_size or model.pos_emb.num_embeddings
+
+    with torch.no_grad():
+        if use_cache:
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
 
ch04/03_kv-cache/tests.py (new file): 101 lines added
@@ -0,0 +1,101 @@
# Code to test the GPT model implementation against the KV cache variants

import pytest
import torch
import tiktoken

from gpt_ch04 import GPTModel as GPTModelBase
from gpt_ch04 import generate_text_simple

from gpt_with_kv_cache import GPTModel as GPTModelKV1
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
from gpt_with_kv_cache import generate_text_simple_cached


GPT_CONFIG_124M = {
    "vocab_size": 50257,
    "context_length": 1024,
    "emb_dim": 768,
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.1,
    "qkv_bias": False,
}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_not_cached(ModelClass):
    torch.manual_seed(123)

    model = ModelClass(GPT_CONFIG_124M).to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I am"
    encoded = tokenizer.encode(prompt)
    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)

    model_name = ModelClass.__module__ + "." + ModelClass.__name__

    token_ids = generate_text_simple(
        model=model,
        idx=encoded_tensor,
        max_new_tokens=30,
        context_size=GPT_CONFIG_124M["context_length"]
    )

    if not hasattr(test_gpt_model_equivalence_not_cached, "results"):
        test_gpt_model_equivalence_not_cached.results = []

    test_gpt_model_equivalence_not_cached.results.append((model_name, token_ids))

    if len(test_gpt_model_equivalence_not_cached.results) == 3:
        base_name, base_output = test_gpt_model_equivalence_not_cached.results[0]
        for other_name, other_output in test_gpt_model_equivalence_not_cached.results[1:]:
            assert torch.equal(base_output, other_output), (
                f"Mismatch between {base_name} and {other_name}"
            )


@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_cached(ModelClass):
    torch.manual_seed(123)

    model = ModelClass(GPT_CONFIG_124M).to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I am"
    encoded_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)

    model_name = ModelClass.__module__ + "." + ModelClass.__name__

    if ModelClass is GPTModelBase:
        token_ids = generate_text_simple(
            model=model,
            idx=encoded_tensor,
            max_new_tokens=30,
            context_size=GPT_CONFIG_124M["context_length"]
        )
    else:
        token_ids = generate_text_simple_cached(
            model=model,
            idx=encoded_tensor,
            max_new_tokens=30,
            context_size=GPT_CONFIG_124M["context_length"]
        )

    if not hasattr(test_gpt_model_equivalence_cached, "results"):
        test_gpt_model_equivalence_cached.results = []

    test_gpt_model_equivalence_cached.results.append((model_name, token_ids))

    if len(test_gpt_model_equivalence_cached.results) == 3:
        base_name, base_output = test_gpt_model_equivalence_cached.results[0]
        for other_name, other_output in test_gpt_model_equivalence_cached.results[1:]:
            assert torch.equal(base_output, other_output), (
                f"Mismatch between {base_name} and {other_name}"
            )