From f656ef996d162ed3e5fac4f3bf205ea10b0a869b Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Tue, 23 Apr 2024 09:51:52 -0500 Subject: [PATCH] Chapter 6 ablation studies (#127) * Chapter 6 ablation studies * add table * formatting * formatting * formatting --- ch06/01_main-chapter-code/README.md | 1 + ch06/02_additional-experiments/README.md | 10 + .../additional-experiments.py | 393 ++++++++++++++++++ .../02_additional-experiments/gpt_download.py | 99 +++++ .../previous_chapters.py | 345 +++++++++++++++ 5 files changed, 848 insertions(+) create mode 100644 ch06/01_main-chapter-code/README.md create mode 100644 ch06/02_additional-experiments/README.md create mode 100644 ch06/02_additional-experiments/additional-experiments.py create mode 100644 ch06/02_additional-experiments/gpt_download.py create mode 100644 ch06/02_additional-experiments/previous_chapters.py diff --git a/ch06/01_main-chapter-code/README.md b/ch06/01_main-chapter-code/README.md new file mode 100644 index 0000000..9ccd531 --- /dev/null +++ b/ch06/01_main-chapter-code/README.md @@ -0,0 +1 @@ +In progress. \ No newline at end of file diff --git a/ch06/02_additional-experiments/README.md b/ch06/02_additional-experiments/README.md new file mode 100644 index 0000000..b29baad --- /dev/null +++ b/ch06/02_additional-experiments/README.md @@ -0,0 +1,10 @@ +# Additional Experiments + +| Model | Trainable token | Trainable layers | CPU/GPU | Training time | Training acc | Validation acc | Test acc | +|--------------------|-----------------|------------------|---------|---------------|--------------|----------------|----------| +| gpt2-small (124M) | last | last_block | V100 | 0.39 min | 96.63% | 97.99% | 94.33% | +| gpt2-small (124M) | first | last_block | V100 | 0.37 min | 78.46% | 80.54% | 75.00% | +| gpt2-small (124M) | last | last_layer | V100 | 0.33 min | 78.65% | 87.25% | 78.33% | +| gpt2-small (124M) | last | all | V100 | 0.94 min | 99.62% | 96.64% | 96.33% | +| gpt2-medium (355M) | last | last_block | V100 | 0.91 min | 87.50% | 51.01% | 56.67% | +| gpt2-large (774M) | last | last_block | V100 | 1.91 min | 99.52% | 98.66% | 96.67% | \ No newline at end of file diff --git a/ch06/02_additional-experiments/additional-experiments.py b/ch06/02_additional-experiments/additional-experiments.py new file mode 100644 index 0000000..c4414e0 --- /dev/null +++ b/ch06/02_additional-experiments/additional-experiments.py @@ -0,0 +1,393 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). 
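+#
+# Ablation experiments for the chapter 6 spam classifier: compares different
+# trainable layer subsets, trainable token positions, and GPT-2 model sizes
+# (see README.md in this folder for the recorded results).
+#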
+# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import argparse +import os +from pathlib import Path +import time +import urllib.request +import zipfile + +import pandas as pd +import tiktoken +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + +from gpt_download import download_and_load_gpt2 +from previous_chapters import GPTModel, load_weights_into_gpt + + +class SpamDataset(Dataset): + def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256): + self.data = pd.read_csv(csv_file) + self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer) + + # Pre-tokenize texts + self.encoded_texts = [ + tokenizer.encode(text)[:self.max_length] + for text in self.data["Text"] + ] + # Pad sequences to the longest sequence + self.encoded_texts = [ + et + [pad_token_id] * (self.max_length - len(et)) + for et in self.encoded_texts + ] + + def __getitem__(self, index): + encoded = self.encoded_texts[index] + label = self.data.iloc[index]["Label"] + return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long) + + def __len__(self): + return len(self.data) + + def _longest_encoded_length(self, tokenizer): + max_length = 0 + for text in self.data["Text"]: + encoded_length = len(tokenizer.encode(text)) + if encoded_length > max_length: + max_length = encoded_length + return max_length + + +def download_and_unzip(url, zip_path, extract_to, new_file_path): + if new_file_path.exists(): + print(f"{new_file_path} already exists. Skipping download and extraction.") + return + + # Downloading the file + with urllib.request.urlopen(url) as response: + with open(zip_path, "wb") as out_file: + out_file.write(response.read()) + + # Unzipping the file + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(extract_to) + + # Renaming the file to indicate its format + original_file = Path(extract_to) / "SMSSpamCollection" + os.rename(original_file, new_file_path) + print(f"File downloaded and saved as {new_file_path}") + + +def random_split(df, train_frac, validation_frac): + # Shuffle the entire DataFrame + df = df.sample(frac=1, random_state=123).reset_index(drop=True) + + # Calculate split indices + train_end = int(len(df) * train_frac) + validation_end = train_end + int(len(df) * validation_frac) + + # Split the DataFrame + train_df = df[:train_end] + validation_df = df[train_end:validation_end] + test_df = df[validation_end:] + + return train_df, validation_df, test_df + + +def create_dataset_csvs(data_file_path): + df = pd.read_csv(new_file_path, sep="\t", header=None, names=["Label", "Text"]) + + # Create balanced dataset + n_spam = df[df["Label"] == "spam"].shape[0] + ham_sampled = df[df["Label"] == "ham"].sample(n_spam, random_state=123) + balanced_df = pd.concat([ham_sampled, df[df["Label"] == "spam"]]) + balanced_df = balanced_df.sample(frac=1, random_state=123).reset_index(drop=True) + balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1}) + + # Sample and save csv files + train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1) + train_df.to_csv("train.csv", index=None) + validation_df.to_csv("validation.csv", index=None) + test_df.to_csv("test.csv", index=None) + + +def instantiate_model(choose_model): + + BASE_CONFIG = { + "vocab_size": 50257, # Vocabulary size + "context_length": 1024, # Context 
length + "drop_rate": 0.0, # Dropout rate + "qkv_bias": True # Query-key-value bias + } + + model_configs = { + "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12}, + "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16}, + "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20}, + "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25}, + } + + BASE_CONFIG.update(model_configs[choose_model]) + + model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")") + settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2") + + model = GPTModel(BASE_CONFIG) + load_weights_into_gpt(model, params) + model.eval() + return model + + +def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1): + input_batch, target_batch = input_batch.to(device), target_batch.to(device) + logits = model(input_batch)[:, trainable_token, :] # Logits of last ouput token + loss = torch.nn.functional.cross_entropy(logits, target_batch) + return loss + + +def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1): + total_loss = 0. + if len(data_loader) == 0: + return float("nan") + elif num_batches is None: + num_batches = len(data_loader) + else: + # Reduce the number of batches to match the total number of batches in the data loader + # if num_batches exceeds the number of batches in the data loader + num_batches = min(num_batches, len(data_loader)) + for i, (input_batch, target_batch) in enumerate(data_loader): + if i < num_batches: + loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token) + total_loss += loss.item() + else: + break + return total_loss / num_batches + + +@torch.no_grad() # Disable gradient tracking for efficiency +def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1): + model.eval() + correct_predictions, num_examples = 0, 0 + + if num_batches is None: + num_batches = len(data_loader) + else: + num_batches = min(num_batches, len(data_loader)) + for i, (input_batch, target_batch) in enumerate(data_loader): + if i < num_batches: + input_batch, target_batch = input_batch.to(device), target_batch.to(device) + logits = model(input_batch)[:, trainable_token, :] # Logits of last ouput token + predicted_labels = torch.argmax(logits, dim=-1) + + num_examples += predicted_labels.shape[0] + correct_predictions += (predicted_labels == target_batch).sum().item() + else: + break + return correct_predictions / num_examples + + +def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1): + model.eval() + with torch.no_grad(): + train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) + val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) + model.train() + return train_loss, val_loss + + +def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs, + eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1): + # Initialize lists to track losses and tokens seen + train_losses, val_losses, train_accs, val_accs = [], [], [], [] + examples_seen, global_step = 0, -1 + + # Main training loop + for epoch in range(num_epochs): + model.train() # Set model to training mode + + for input_batch, target_batch in train_loader: + optimizer.zero_grad() # Reset loss gradients from previous epoch + loss = calc_loss_batch(input_batch, 
target_batch, model, device, trainable_token=trainable_token) + loss.backward() # Calculate loss gradients + optimizer.step() # Update model weights using loss gradients + examples_seen += input_batch.shape[0] # New: track examples instead of tokens + global_step += 1 + + # Optional evaluation step + if global_step % eval_freq == 0: + train_loss, val_loss = evaluate_model( + model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token) + train_losses.append(train_loss) + val_losses.append(val_loss) + print(f"Ep {epoch+1} (Step {global_step:06d}): " + f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}") + + if max_steps is not None and global_step > max_steps: + break + + # New: Calculate accuracy after each epoch + train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) + val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) + print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="") + print(f"Validation accuracy: {val_accuracy*100:.2f}%") + train_accs.append(train_accuracy) + val_accs.append(val_accuracy) + + if max_steps is not None and global_step > max_steps: + break + + return train_losses, val_losses, train_accs, val_accs, examples_seen + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_size", + type=str, + default="gpt2-small (124M)", + help=( + "Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)'," + " 'gpt2-large (774M)', 'gpt2-xl (1558M)'." + ) + ) + parser.add_argument( + "--trainable_layers", + type=str, + default="last_block", + help=( + "Which layers to train. Options: 'all', 'last_block', 'last_layer'." + ) + ) + parser.add_argument( + "--trainable_token", + type=str, + default="last", + help=( + "Which token to train. Options: 'first', 'last'." 
+        )
+    )
+
+    args = parser.parse_args()
+
+    if args.trainable_token == "first":
+        args.trainable_token = 0
+    elif args.trainable_token == "last":
+        args.trainable_token = -1
+    else:
+        raise ValueError("Invalid --trainable_token argument")
+
+    ###############################
+    # Instantiate dataloaders
+    ###############################
+
+    url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
+    zip_path = "sms_spam_collection.zip"
+    extract_to = "sms_spam_collection"
+    new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
+
+    base_path = Path(".")
+    file_names = ["train.csv", "validation.csv", "test.csv"]
+    all_exist = all((base_path / file_name).exists() for file_name in file_names)
+
+    if not all_exist:
+        download_and_unzip(url, zip_path, extract_to, new_file_path)
+        create_dataset_csvs(new_file_path)
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+
+    train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer)
+    val_dataset = SpamDataset(base_path / "validation.csv", max_length=None, tokenizer=tokenizer)
+    test_dataset = SpamDataset(base_path / "test.csv", max_length=None, tokenizer=tokenizer)
+
+    num_workers = 0
+    batch_size = 8
+
+    train_loader = DataLoader(
+        dataset=train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+        drop_last=True,
+    )
+
+    val_loader = DataLoader(
+        dataset=val_dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        drop_last=False,
+    )
+
+    test_loader = DataLoader(
+        dataset=test_dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        drop_last=False,
+    )
+
+    ###############################
+    # Load model
+    ###############################
+
+    model = instantiate_model(args.model_size)
+    for param in model.parameters():
+        param.requires_grad = False
+
+    if args.model_size == "gpt2-small (124M)":
+        in_features = 768
+    elif args.model_size == "gpt2-medium (355M)":
+        in_features = 1024
+    elif args.model_size == "gpt2-large (774M)":
+        in_features = 1280
+    elif args.model_size == "gpt2-xl (1558M)":
+        in_features = 1600
+    else:
+        raise ValueError("Invalid --model_size argument")
+
+    torch.manual_seed(123)
+    model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
+
+    if args.trainable_layers == "last_layer":
+        pass
+    elif args.trainable_layers == "last_block":
+        for param in model.trf_blocks[-1].parameters():
+            param.requires_grad = True
+        for param in model.final_norm.parameters():
+            param.requires_grad = True
+    elif args.trainable_layers == "all":
+        for param in model.parameters():
+            param.requires_grad = True
+    else:
+        raise ValueError("Invalid --trainable_layers argument.")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    ###############################
+    # Train model
+    ###############################
+
+    start_time = time.time()
+    torch.manual_seed(123)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
+
+    num_epochs = 5
+    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
+        model, train_loader, val_loader, optimizer, device,
+        num_epochs=num_epochs, eval_freq=50, eval_iter=5,
+        tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token
+    )
+
+    end_time = time.time()
+    execution_time_minutes = (end_time - start_time) / 60
+    print(f"Training completed in {execution_time_minutes:.2f} minutes.")
+
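+    # Optional sketch (an assumption, not part of the original script): persist the
+    # per-epoch accuracies so runs with different settings can be compared against
+    # the table in README.md, e.g.:
+    # pd.DataFrame({"train_acc": train_accs, "val_acc": val_accs}).to_csv("accuracy_log.csv", index=False)
+
+    ###############################
+    # 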
Evaluate model + ############################### + + train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token) + val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token) + test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token) + + print(f"Training accuracy: {train_accuracy*100:.2f}%") + print(f"Validation accuracy: {val_accuracy*100:.2f}%") + print(f"Test accuracy: {test_accuracy*100:.2f}%") diff --git a/ch06/02_additional-experiments/gpt_download.py b/ch06/02_additional-experiments/gpt_download.py new file mode 100644 index 0000000..0d695d2 --- /dev/null +++ b/ch06/02_additional-experiments/gpt_download.py @@ -0,0 +1,99 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + + +import os +import requests +import json +import numpy as np +import tensorflow as tf +from tqdm import tqdm + + +def download_and_load_gpt2(model_size, models_dir): + # Validate model size + allowed_sizes = ("124M", "355M", "774M", "1558M") + if model_size not in allowed_sizes: + raise ValueError(f"Model size not in {allowed_sizes}") + + # Define paths + model_dir = os.path.join(models_dir, model_size) + base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models" + filenames = [ + "checkpoint", "encoder.json", "hparams.json", + "model.ckpt.data-00000-of-00001", "model.ckpt.index", + "model.ckpt.meta", "vocab.bpe" + ] + + # Download files + os.makedirs(model_dir, exist_ok=True) + for filename in filenames: + file_url = os.path.join(base_url, model_size, filename) + file_path = os.path.join(model_dir, filename) + download_file(file_url, file_path) + + # Load settings and params + tf_ckpt_path = tf.train.latest_checkpoint(model_dir) + settings = json.load(open(os.path.join(model_dir, "hparams.json"))) + params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings) + + return settings, params + + +def download_file(url, destination): + # Send a GET request to download the file in streaming mode + response = requests.get(url, stream=True) + + # Get the total file size from headers, defaulting to 0 if not present + file_size = int(response.headers.get("content-length", 0)) + + # Check if file exists and has the same size + if os.path.exists(destination): + file_size_local = os.path.getsize(destination) + if file_size == file_size_local: + print(f"File already exists and is up-to-date: {destination}") + return + + # Define the block size for reading the file + block_size = 1024 # 1 Kilobyte + + # Initialize the progress bar with total file size + progress_bar_description = url.split("/")[-1] # Extract filename from URL + with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar: + # Open the destination file in binary write mode + with open(destination, "wb") as file: + # Iterate over the file data in chunks + for chunk in response.iter_content(block_size): + progress_bar.update(len(chunk)) # Update progress bar + file.write(chunk) # Write the chunk to the file + + +def load_gpt2_params_from_tf_ckpt(ckpt_path, settings): + # Initialize parameters dictionary with empty blocks for each layer + params = {"blocks": [{} for _ in range(settings["n_layer"])]} + + # Iterate over each variable in the checkpoint + for name, _ in 
tf.train.list_variables(ckpt_path): + # Load the variable and remove singleton dimensions + variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name)) + + # Process the variable name to extract relevant parts + variable_name_parts = name.split("/")[1:] # Skip the 'model/' prefix + + # Identify the target dictionary for the variable + target_dict = params + if variable_name_parts[0].startswith("h"): + layer_number = int(variable_name_parts[0][1:]) + target_dict = params["blocks"][layer_number] + + # Recursively access or create nested dictionaries + for key in variable_name_parts[1:-1]: + target_dict = target_dict.setdefault(key, {}) + + # Assign the variable array to the last key + last_key = variable_name_parts[-1] + target_dict[last_key] = variable_array + + return params diff --git a/ch06/02_additional-experiments/previous_chapters.py b/ch06/02_additional-experiments/previous_chapters.py new file mode 100644 index 0000000..e794f9b --- /dev/null +++ b/ch06/02_additional-experiments/previous_chapters.py @@ -0,0 +1,345 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch +# +# This file collects all the relevant code that we covered thus far +# throughout Chapters 2-5. +# This file can be run as a standalone script. + +import numpy as np +import tiktoken +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader + +##################################### +# Chapter 2 +##################################### + + +class GPTDatasetV1(Dataset): + def __init__(self, txt, tokenizer, max_length, stride): + self.tokenizer = tokenizer + self.input_ids = [] + self.target_ids = [] + + # Tokenize the entire text + token_ids = tokenizer.encode(txt) + + # Use a sliding window to chunk the book into overlapping sequences of max_length + for i in range(0, len(token_ids) - max_length, stride): + input_chunk = token_ids[i:i + max_length] + target_chunk = token_ids[i + 1: i + max_length + 1] + self.input_ids.append(torch.tensor(input_chunk)) + self.target_ids.append(torch.tensor(target_chunk)) + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return self.input_ids[idx], self.target_ids[idx] + + +def create_dataloader_v1(txt, batch_size=4, max_length=256, + stride=128, shuffle=True, drop_last=True): + # Initialize the tokenizer + tokenizer = tiktoken.get_encoding("gpt2") + + # Create dataset + dataset = GPTDatasetV1(txt, tokenizer, max_length, stride) + + # Create dataloader + dataloader = DataLoader( + dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) + + return dataloader + + +##################################### +# Chapter 3 +##################################### +class MultiHeadAttention(nn.Module): + def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by n_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + 
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) + + def forward(self, x): + b, num_tokens, d_in = x.shape + + keys = self.W_key(x) # Shape: (b, num_tokens, d_out) + queries = self.W_query(x) + values = self.W_value(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) + values = values.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + + # Original mask truncated to the number of tokens and converted to boolean + mask_bool = self.mask.bool()[:num_tokens, :num_tokens] + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.reshape(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + +##################################### +# Chapter 4 +##################################### +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), + GELU(), + nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), + ) + + def forward(self, x): + return self.layers(x) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadAttention( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + context_length=cfg["context_length"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"]) + self.ff = FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_resid = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + x = self.drop_resid(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + x = self.ff(x) + x = self.drop_resid(x) + x = x + shortcut # Add the original input back + + return x + + +class 
GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + self.trf_blocks = nn.Sequential( + *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + x = self.trf_blocks(x) + x = self.final_norm(x) + logits = self.out_head(x) + return logits + + +def generate_text_simple(model, idx, max_new_tokens, context_size): + # idx is (B, T) array of indices in the current context + for _ in range(max_new_tokens): + + # Crop current context if it exceeds the supported context size + # E.g., if LLM supports only 5 tokens, and the context size is 10 + # then only the last 5 tokens are used as context + idx_cond = idx[:, -context_size:] + + # Get the predictions + with torch.no_grad(): + logits = model(idx_cond) + + # Focus only on the last time step + # (batch, n_token, vocab_size) becomes (batch, vocab_size) + logits = logits[:, -1, :] + + # Get the idx of the vocab entry with the highest logits value + idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1) + + # Append sampled index to the running sequence + idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1) + + return idx + + +##################################### +# Chapter 5 +##################################### +def assign(left, right): + if left.shape != right.shape: + raise ValueError(f"Shape mismatch. 
Left: {left.shape}, Right: {right.shape}") + return torch.nn.Parameter(torch.tensor(right)) + + +def load_weights_into_gpt(gpt, params): + gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe']) + gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte']) + + for b in range(len(params["blocks"])): + q_w, k_w, v_w = np.split( + (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1) + gpt.trf_blocks[b].att.W_query.weight = assign( + gpt.trf_blocks[b].att.W_query.weight, q_w.T) + gpt.trf_blocks[b].att.W_key.weight = assign( + gpt.trf_blocks[b].att.W_key.weight, k_w.T) + gpt.trf_blocks[b].att.W_value.weight = assign( + gpt.trf_blocks[b].att.W_value.weight, v_w.T) + + q_b, k_b, v_b = np.split( + (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1) + gpt.trf_blocks[b].att.W_query.bias = assign( + gpt.trf_blocks[b].att.W_query.bias, q_b) + gpt.trf_blocks[b].att.W_key.bias = assign( + gpt.trf_blocks[b].att.W_key.bias, k_b) + gpt.trf_blocks[b].att.W_value.bias = assign( + gpt.trf_blocks[b].att.W_value.bias, v_b) + + gpt.trf_blocks[b].att.out_proj.weight = assign( + gpt.trf_blocks[b].att.out_proj.weight, + params["blocks"][b]["attn"]["c_proj"]["w"].T) + gpt.trf_blocks[b].att.out_proj.bias = assign( + gpt.trf_blocks[b].att.out_proj.bias, + params["blocks"][b]["attn"]["c_proj"]["b"]) + + gpt.trf_blocks[b].ff.layers[0].weight = assign( + gpt.trf_blocks[b].ff.layers[0].weight, + params["blocks"][b]["mlp"]["c_fc"]["w"].T) + gpt.trf_blocks[b].ff.layers[0].bias = assign( + gpt.trf_blocks[b].ff.layers[0].bias, + params["blocks"][b]["mlp"]["c_fc"]["b"]) + gpt.trf_blocks[b].ff.layers[2].weight = assign( + gpt.trf_blocks[b].ff.layers[2].weight, + params["blocks"][b]["mlp"]["c_proj"]["w"].T) + gpt.trf_blocks[b].ff.layers[2].bias = assign( + gpt.trf_blocks[b].ff.layers[2].bias, + params["blocks"][b]["mlp"]["c_proj"]["b"]) + + gpt.trf_blocks[b].norm1.scale = assign( + gpt.trf_blocks[b].norm1.scale, + params["blocks"][b]["ln_1"]["g"]) + gpt.trf_blocks[b].norm1.shift = assign( + gpt.trf_blocks[b].norm1.shift, + params["blocks"][b]["ln_1"]["b"]) + gpt.trf_blocks[b].norm2.scale = assign( + gpt.trf_blocks[b].norm2.scale, + params["blocks"][b]["ln_2"]["g"]) + gpt.trf_blocks[b].norm2.shift = assign( + gpt.trf_blocks[b].norm2.shift, + params["blocks"][b]["ln_2"]["b"]) + + gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"]) + gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"]) + gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"]) + + +def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None): + # For-loop is the same as before: Get logits, and only focus on last time step + for _ in range(max_new_tokens): + idx_cond = idx[:, -context_size:] + with torch.no_grad(): + logits = model(idx_cond) + logits = logits[:, -1, :] + + # New: Filter logits with top_k sampling + if top_k is not None: + # Keep only top_k values + top_logits, _ = torch.topk(logits, top_k) + min_val = top_logits[:, -1] + logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits) + + # New: Apply temperature scaling + if temperature > 0.0: + logits = logits / temperature + + # Apply softmax to get probabilities + probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) + + # Sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1) + + # Otherwise same as before: get idx of the vocab entry with the highest logits value + else: + idx_next = torch.argmax(logits, dim=-1, keepdim=True) # 
(batch_size, 1) + + # Same as before: append sampled index to the running sequence + idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1) + + return idx
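+
+
+if __name__ == "__main__":
+    # Minimal smoke test so the file can be run as a standalone script, as noted in
+    # the header. This is a sketch rather than chapter code: it builds a small,
+    # randomly initialized GPTModel with an assumed toy config and greedily
+    # generates a few tokens with generate_text_simple.
+    torch.manual_seed(123)
+    cfg = {
+        "vocab_size": 50257,    # GPT-2 BPE vocabulary size
+        "context_length": 128,  # Assumed small context for the test
+        "emb_dim": 96,          # Assumed small embedding dimension
+        "n_heads": 4,
+        "n_layers": 2,
+        "drop_rate": 0.0,
+        "qkv_bias": False,
+    }
+    model = GPTModel(cfg)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    idx = torch.tensor([tokenizer.encode("Every effort moves you")])
+    out = generate_text_simple(model, idx, max_new_tokens=5, context_size=cfg["context_length"])
+    print("Output token IDs:", out)
+    print("Decoded text:", tokenizer.decode(out.squeeze(0).tolist()))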