2025-03-23 19:28:49 -05:00
|
|
|
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
|
|
|
# Source for "Build a Large Language Model From Scratch"
|
|
|
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
|
|
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
|
|
|
|
2025-03-27 14:00:25 -05:00
|
|
|
from llms_from_scratch.ch04 import GPTModel, GPTModelFast
|
2025-03-23 19:28:49 -05:00
|
|
|
from llms_from_scratch.ch04 import generate_text_simple
|
|
|
|
|
2025-03-27 14:00:25 -05:00
|
|
|
import pytest
|
2025-03-23 19:28:49 -05:00
|
|
|
import torch
|
|
|
|
import tiktoken
|
|
|
|
|
|
|
|
|
2025-03-27 14:00:25 -05:00
|
|
|
# Shared model hyperparameters (named for the 124M GPT configuration); passed
# unchanged to every model class under test.
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # Vocabulary size
    context_length=1024,  # Context length
    emb_dim=768,          # Embedding dimension
    n_heads=12,           # Number of attention heads
    n_layers=12,          # Number of layers
    drop_rate=0.1,        # Dropout rate
    qkv_bias=False,       # Query-Key-Value bias
)
|
2025-03-23 19:28:49 -05:00
|
|
|
|
2025-03-27 14:00:25 -05:00
|
|
|
|
|
|
|
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
def test_gpt_model_variants(ModelClass):
    """Check that each GPT model variant generates the reference token sequence.

    The model is built from ``GPT_CONFIG_124M`` under a fixed seed, put in
    eval mode, prompted with "Hello, I am" (encoded with the GPT-2
    tokenizer), and asked for 10 new tokens via ``generate_text_simple``.
    The returned sequence (prompt tokens followed by generated tokens) must
    match the hard-coded ``expect`` tensor exactly.
    """
    # Fixed seed so both model classes start from reproducible weights.
    torch.manual_seed(123)
    model = ModelClass(GPT_CONFIG_124M)
    model.eval()  # disable dropout

    start_context = "Hello, I am"

    tokenizer = tiktoken.get_encoding("gpt2")
    encoded = tokenizer.encode(start_context)
    # Add a leading batch dimension: shape (1, num_prompt_tokens).
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)

    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", start_context)
    print("Encoded input text:", encoded)
    print("encoded_tensor.shape:", encoded_tensor.shape)

    out = generate_text_simple(
        model=model,
        idx=encoded_tensor,
        max_new_tokens=10,
        context_size=GPT_CONFIG_124M["context_length"],
    )

    # 4 prompt tokens followed by 10 generated tokens.
    expect = torch.tensor([
        [15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267,
         49706, 43231, 47062, 34657]
    ])

    # Compare shapes first so a length mismatch gets a clear message instead
    # of a bare "not equal".
    assert out.shape == expect.shape, (
        f"Unexpected output shape: expected {tuple(expect.shape)}, "
        f"got {tuple(out.shape)}"
    )
    # Include both token lists so a failure can be diagnosed without rerunning.
    assert torch.equal(expect, out), (
        "Generated output does not match expected output:\n"
        f"expected {expect.tolist()}\n"
        f"got      {out.tolist()}"
    )
|