Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-08-27 18:10:39 +00:00)

Commit 713a6e24c9: add tests
Parent: ffc5e4e5d6

@@ -263,33 +263,27 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
     return idx
 
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    ctx_len = context_size or model.pos_emb.num_embeddings
+
+    with torch.no_grad():
+        if use_cache:
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
                 logits = model(next_idx, use_cache=True)
-
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
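As a quick illustration of the change above, here is a minimal, hypothetical usage sketch (not part of the commit). It imports from gpt_with_kv_cache the same way tests.py below does, reuses the GPT_CONFIG_124M dictionary defined there, and simply exercises the new optional context_size keyword.

# Hypothetical usage sketch for the updated generate_text_simple_cached signature.
# GPT_CONFIG_124M matches the config defined in ch04/03_kv-cache/tests.py below.
import torch
import tiktoken

from gpt_with_kv_cache import GPTModel, generate_text_simple_cached

GPT_CONFIG_124M = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
    "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False,
}

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.eval()

tokenizer = tiktoken.get_encoding("gpt2")
idx = torch.tensor(tokenizer.encode("Hello, I am")).unsqueeze(0)

# context_size is now optional; omitting it falls back to model.pos_emb.num_embeddings.
token_ids = generate_text_simple_cached(
    model=model,
    idx=idx,
    max_new_tokens=30,
    context_size=GPT_CONFIG_124M["context_length"],
)
print(tokenizer.decode(token_ids.squeeze(0).tolist()))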
@@ -171,7 +171,8 @@ class TransformerBlock(nn.Module):
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
             qkv_bias=cfg["qkv_bias"],
-            window_size=cfg["kv_window_size"])  # NEW
+            window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]  # NEW
+        )
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
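The new conditional simply falls back to the full context length when a config does not define "kv_window_size" (as is the case for GPT_CONFIG_124M in tests.py below). A behaviorally equivalent spelling, shown here only as an illustration and not as the code used in the commit:

# Illustration only: dict.get expresses the same fallback as the conditional above.
# The 256-token window value is made up for the example.
cfg_with_window = {"context_length": 1024, "kv_window_size": 256}
cfg_without_window = {"context_length": 1024}

for cfg in (cfg_with_window, cfg_without_window):
    window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]
    assert window_size == cfg.get("kv_window_size", cfg["context_length"])
    print(window_size)  # prints 256, then 1024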
@@ -289,30 +290,25 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    ctx_len = context_size or model.pos_emb.num_embeddings
+
+    with torch.no_grad():
+        if use_cache:
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
                 logits = model(next_idx, use_cache=True)
-
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
ch04/03_kv-cache/tests.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# Code to test the GPT model implementation against the KV cache variants

import pytest
import torch
import time
import tiktoken

from gpt_ch04 import GPTModel as GPTModelBase
from gpt_ch04 import generate_text_simple

from gpt_with_kv_cache import GPTModel as GPTModelKV1
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
from gpt_with_kv_cache import generate_text_simple_cached


GPT_CONFIG_124M = {
    "vocab_size": 50257,
    "context_length": 1024,
    "emb_dim": 768,
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.1,
    "qkv_bias": False,
}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_not_cached(ModelClass):
    torch.manual_seed(123)

    model = ModelClass(GPT_CONFIG_124M).to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I am"
    encoded = tokenizer.encode(prompt)
    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)

    model_name = ModelClass.__module__ + "." + ModelClass.__name__

    token_ids = generate_text_simple(
        model=model,
        idx=encoded_tensor,
        max_new_tokens=30,
        context_size=GPT_CONFIG_124M["context_length"]
    )

    if not hasattr(test_gpt_model_equivalence_not_cached, "results"):
        test_gpt_model_equivalence_not_cached.results = []

    test_gpt_model_equivalence_not_cached.results.append((model_name, token_ids))

    if len(test_gpt_model_equivalence_not_cached.results) == 3:
        base_name, base_output = test_gpt_model_equivalence_not_cached.results[0]
        for other_name, other_output in test_gpt_model_equivalence_not_cached.results[1:]:
            assert torch.equal(base_output, other_output), (
                f"Mismatch between {base_name} and {other_name}"
            )


@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_cached(ModelClass):
    torch.manual_seed(123)

    model = ModelClass(GPT_CONFIG_124M).to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I am"
    encoded_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)

    model_name = ModelClass.__module__ + "." + ModelClass.__name__

    if ModelClass is GPTModelBase:
        token_ids = generate_text_simple(
            model=model,
            idx=encoded_tensor,
            max_new_tokens=30,
            context_size=GPT_CONFIG_124M["context_length"]
        )
    else:
        token_ids = generate_text_simple_cached(
            model=model,
            idx=encoded_tensor,
            max_new_tokens=30,
            context_size=GPT_CONFIG_124M["context_length"]
        )

    if not hasattr(test_gpt_model_equivalence_cached, "results"):
        test_gpt_model_equivalence_cached.results = []

    test_gpt_model_equivalence_cached.results.append((model_name, token_ids))

    if len(test_gpt_model_equivalence_cached.results) == 3:
        base_name, base_output = test_gpt_model_equivalence_cached.results[0]
        for other_name, other_output in test_gpt_model_equivalence_cached.results[1:]:
            assert torch.equal(base_output, other_output), (
                f"Mismatch between {base_name} and {other_name}"
            )
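Assuming pytest, torch, and tiktoken are installed and the working directory is ch04/03_kv-cache (so the gpt_* modules resolve), the new tests can be run with `pytest tests.py`; a programmatic equivalent is sketched below as an illustration. Note that each parametrized run stashes its output on a function attribute and only the third run performs the cross-model comparison, so the tests are meant to be run as a full set.

# Hypothetical runner, equivalent to `pytest -v tests.py` from inside ch04/03_kv-cache.
import sys
import pytest

if __name__ == "__main__":
    # -v lists each parametrized case (GPTModelBase, GPTModelKV1, GPTModelKV2) separately.
    sys.exit(pytest.main(["-v", "tests.py"]))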