restore file

rasbt 2024-06-03 07:17:56 -05:00
parent d51099a9e7
commit 089dfb756a


@@ -11,12 +11,13 @@ import tiktoken
 import torch
 import torch.nn as nn
 from torch.utils.data import Dataset, DataLoader
+import matplotlib.pyplot as plt


 #####################################
 # Chapter 2
 #####################################

 class GPTDatasetV1(Dataset):
     def __init__(self, txt, tokenizer, max_length, stride):
         self.input_ids = []
@@ -57,10 +58,11 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
 #####################################
 # Chapter 3
 #####################################

 class MultiHeadAttention(nn.Module):
     def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
         super().__init__()
-        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
+
         self.d_out = d_out
         self.num_heads = num_heads
@@ -107,7 +109,7 @@ class MultiHeadAttention(nn.Module):
         context_vec = (attn_weights @ values).transpose(1, 2)

         # Combine heads, where self.d_out = self.num_heads * self.head_dim
-        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
+        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
         context_vec = self.out_proj(context_vec)  # optional projection

         return context_vec
@@ -116,6 +118,7 @@ class MultiHeadAttention(nn.Module):
 #####################################
 # Chapter 4
 #####################################
+
 class LayerNorm(nn.Module):
     def __init__(self, emb_dim):
         super().__init__()
@@ -238,42 +241,82 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
     return idx


-if __name__ == "__main__":
-
-    GPT_CONFIG_124M = {
-        "vocab_size": 50257,     # Vocabulary size
-        "context_length": 1024,  # Context length
-        "emb_dim": 768,          # Embedding dimension
-        "n_heads": 12,           # Number of attention heads
-        "n_layers": 12,          # Number of layers
-        "drop_rate": 0.1,        # Dropout rate
-        "qkv_bias": False        # Query-Key-Value bias
-    }
-
-    torch.manual_seed(123)
-    model = GPTModel(GPT_CONFIG_124M)
-    model.eval()  # disable dropout
-
-    start_context = "Hello, I am"
-
-    tokenizer = tiktoken.get_encoding("gpt2")
-    encoded = tokenizer.encode(start_context)
-    encoded_tensor = torch.tensor(encoded).unsqueeze(0)
-
-    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
-    print("\nInput text:", start_context)
-    print("Encoded input text:", encoded)
-    print("encoded_tensor.shape:", encoded_tensor.shape)
-
-    out = generate_text_simple(
-        model=model,
-        idx=encoded_tensor,
-        max_new_tokens=10,
-        context_size=GPT_CONFIG_124M["context_length"]
-    )
-    decoded_text = tokenizer.decode(out.squeeze(0).tolist())
-
-    print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
-    print("\nOutput:", out)
-    print("Output length:", len(out[0]))
-    print("Output text:", decoded_text)
+#####################################
+# Chapter 5
+####################################
+
+
+def calc_loss_batch(input_batch, target_batch, model, device):
+    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+    logits = model(input_batch)
+    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
+    return loss
+
+
+def calc_loss_loader(data_loader, model, device, num_batches=None):
+    total_loss = 0.
+    if len(data_loader) == 0:
+        return float("nan")
+    elif num_batches is None:
+        num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
+    for i, (input_batch, target_batch) in enumerate(data_loader):
+        if i < num_batches:
+            loss = calc_loss_batch(input_batch, target_batch, model, device)
+            total_loss += loss.item()
+        else:
+            break
+    return total_loss / num_batches
+
+
+def evaluate_model(model, train_loader, val_loader, device, eval_iter):
+    model.eval()
+    with torch.no_grad():
+        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
+        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
+    model.train()
+    return train_loss, val_loss
+
+
+def generate_and_print_sample(model, tokenizer, device, start_context):
+    model.eval()
+    context_size = model.pos_emb.weight.shape[0]
+    encoded = text_to_token_ids(start_context, tokenizer).to(device)
+    with torch.no_grad():
+        token_ids = generate_text_simple(
+            model=model, idx=encoded,
+            max_new_tokens=50, context_size=context_size)
+        decoded_text = token_ids_to_text(token_ids, tokenizer)
+        print(decoded_text.replace("\n", " "))  # Compact print format
+    model.train()
+
+
+def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
+    fig, ax1 = plt.subplots(figsize=(5, 3))
+
+    # Plot training and validation loss against epochs
+    ax1.plot(epochs_seen, train_losses, label="Training loss")
+    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
+    ax1.set_xlabel("Epochs")
+    ax1.set_ylabel("Loss")
+    ax1.legend(loc="upper right")
+
+    # Create a second x-axis for tokens seen
+    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
+    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
+    ax2.set_xlabel("Tokens seen")
+
+    fig.tight_layout()  # Adjust layout to make room
+    # plt.show()
+
+
+def text_to_token_ids(text, tokenizer):
+    encoded = tokenizer.encode(text)
+    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
+    return encoded_tensor
+
+
+def token_ids_to_text(token_ids, tokenizer):
+    flat = token_ids.squeeze(0)  # remove batch dimension
+    return tokenizer.decode(flat.tolist())