From 88176a82eb83d0ae38aef6b56956c78b1948f19d Mon Sep 17 00:00:00 2001
From: rasbt
Date: Sun, 12 May 2024 18:27:50 -0500
Subject: [PATCH] chapter 06 summary file

---
 README.md                 |   2 +-
 .../gpt-class-finetune.py | 381 ++++++++++++++++++
 2 files changed, 382 insertions(+), 1 deletion(-)
 create mode 100644 ch06/01_main-chapter-code/gpt-class-finetune.py

diff --git a/README.md b/README.md
index 96ba1dc..f74e3d0 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,7 @@ Alternatively, you can view this and other files on GitHub at [https://github.co
 | Ch 3: Coding Attention Mechanisms | - [ch03.ipynb](ch03/01_main-chapter-code/ch03.ipynb)
- [multihead-attention.ipynb](ch03/01_main-chapter-code/multihead-attention.ipynb) (summary)
- [exercise-solutions.ipynb](ch03/01_main-chapter-code/exercise-solutions.ipynb)| [./ch03](./ch03) |
| Ch 4: Implementing a GPT Model from Scratch | - [ch04.ipynb](ch04/01_main-chapter-code/ch04.ipynb)
- [gpt.py](ch04/01_main-chapter-code/gpt.py) (summary)
- [exercise-solutions.ipynb](ch04/01_main-chapter-code/exercise-solutions.ipynb) | [./ch04](./ch04) |
| Ch 5: Pretraining on Unlabeled Data | - [ch05.ipynb](ch05/01_main-chapter-code/ch05.ipynb)
- [gpt_train.py](ch05/01_main-chapter-code/gpt_train.py) (summary)
- [gpt_generate.py](ch05/01_main-chapter-code/gpt_generate.py) (summary)
- [exercise-solutions.ipynb](ch05/01_main-chapter-code/exercise-solutions.ipynb) | [./ch05](./ch05) |
-| Ch 6: Finetuning for Text Classification | - [ch06.ipynb](ch06/01_main-chapter-code/ch06.ipynb) | [./ch06](./ch06) |
+| Ch 6: Finetuning for Text Classification | - [ch06.ipynb](ch06/01_main-chapter-code/ch06.ipynb)
- [gpt-class-finetune.py](ch06/01_main-chapter-code/gpt-class-finetune.py) | [./ch06](./ch06) |
| Ch 7: Finetuning with Human Feedback | Q2 2024 | ... |
| Appendix A: Introduction to PyTorch | - [code-part1.ipynb](appendix-A/01_main-chapter-code/code-part1.ipynb)
- [code-part2.ipynb](appendix-A/01_main-chapter-code/code-part2.ipynb)
- [DDP-script.py](appendix-A/01_main-chapter-code/DDP-script.py)
- [exercise-solutions.ipynb](appendix-A/01_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) | | Appendix B: References and Further Reading | No code | - | diff --git a/ch06/01_main-chapter-code/gpt-class-finetune.py b/ch06/01_main-chapter-code/gpt-class-finetune.py new file mode 100644 index 0000000..4adbbe8 --- /dev/null +++ b/ch06/01_main-chapter-code/gpt-class-finetune.py @@ -0,0 +1,381 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# This is a summary file containing the main takeaways from chapter 6. + +import urllib.request +import zipfile +import os +from pathlib import Path +import time + +import matplotlib.pyplot as plt +import pandas as pd +import tiktoken +import torch +from torch.utils.data import Dataset, DataLoader + +from gpt_download import download_and_load_gpt2 +from previous_chapters import GPTModel, load_weights_into_gpt + + +def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path): + if data_file_path.exists(): + print(f"{data_file_path} already exists. Skipping download and extraction.") + return + + # Downloading the file + with urllib.request.urlopen(url) as response: + with open(zip_path, "wb") as out_file: + out_file.write(response.read()) + + # Unzipping the file + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extracted_path) + + # Add .tsv file extension + original_file_path = Path(extracted_path) / "SMSSpamCollection" + os.rename(original_file_path, data_file_path) + print(f"File downloaded and saved as {data_file_path}") + + +def create_balanced_dataset(df): + # Count the instances of "spam" + num_spam = df[df["Label"] == "spam"].shape[0] + + # Randomly sample "ham" instances to match the number of "spam" instances + ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123) + + # Combine ham "subset" with "spam" + balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]]) + + return balanced_df + + +def random_split(df, train_frac, validation_frac): + # Shuffle the entire DataFrame + df = df.sample(frac=1, random_state=123).reset_index(drop=True) + + # Calculate split indices + train_end = int(len(df) * train_frac) + validation_end = train_end + int(len(df) * validation_frac) + + # Split the DataFrame + train_df = df[:train_end] + validation_df = df[train_end:validation_end] + test_df = df[validation_end:] + + return train_df, validation_df, test_df + + +class SpamDataset(Dataset): + def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256): + self.data = pd.read_csv(csv_file) + + # Pre-tokenize texts + self.encoded_texts = [ + tokenizer.encode(text) for text in self.data["Text"] + ] + + if max_length is None: + self.max_length = self._longest_encoded_length() + else: + self.max_length = max_length + # Truncate sequences if they are longer than max_length + self.encoded_texts = [ + encoded_text[:self.max_length] + for encoded_text in self.encoded_texts + ] + + # Pad sequences to the longest sequence + self.encoded_texts = [ + encoded_text + [pad_token_id] * (self.max_length - len(encoded_text)) + for encoded_text in self.encoded_texts + ] + + def __getitem__(self, index): + encoded = self.encoded_texts[index] + label = self.data.iloc[index]["Label"] + return ( + torch.tensor(encoded, dtype=torch.long), + torch.tensor(label, 
dtype=torch.long) + ) + + def __len__(self): + return len(self.data) + + def _longest_encoded_length(self): + max_length = 0 + for encoded_text in self.encoded_texts: + encoded_length = len(encoded_text) + if encoded_length > max_length: + max_length = encoded_length + return max_length + + +def calc_accuracy_loader(data_loader, model, device, num_batches=None): + model.eval() + correct_predictions, num_examples = 0, 0 + + if num_batches is None: + num_batches = len(data_loader) + else: + num_batches = min(num_batches, len(data_loader)) + for i, (input_batch, target_batch) in enumerate(data_loader): + if i < num_batches: + input_batch, target_batch = input_batch.to(device), target_batch.to(device) + + with torch.no_grad(): + logits = model(input_batch)[:, -1, :] # Logits of last output token + predicted_labels = torch.argmax(logits, dim=-1) + + num_examples += predicted_labels.shape[0] + correct_predictions += (predicted_labels == target_batch).sum().item() + else: + break + return correct_predictions / num_examples + + +def calc_loss_batch(input_batch, target_batch, model, device): + input_batch, target_batch = input_batch.to(device), target_batch.to(device) + logits = model(input_batch)[:, -1, :] # Logits of last output token + loss = torch.nn.functional.cross_entropy(logits, target_batch) + return loss + + +def calc_loss_loader(data_loader, model, device, num_batches=None): + total_loss = 0. + if len(data_loader) == 0: + return float("nan") + elif num_batches is None: + num_batches = len(data_loader) + else: + num_batches = min(num_batches, len(data_loader)) + for i, (input_batch, target_batch) in enumerate(data_loader): + if i < num_batches: + loss = calc_loss_batch(input_batch, target_batch, model, device) + total_loss += loss.item() + else: + break + return total_loss / num_batches + + +def evaluate_model(model, train_loader, val_loader, device, eval_iter): + model.eval() + with torch.no_grad(): + train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter) + val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter) + model.train() + return train_loss, val_loss + + +def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs, + eval_freq, eval_iter, tokenizer): + # Initialize lists to track losses and tokens seen + train_losses, val_losses, train_accs, val_accs = [], [], [], [] + examples_seen, global_step = 0, -1 + + # Main training loop + for epoch in range(num_epochs): + model.train() # Set model to training mode + + for input_batch, target_batch in train_loader: + optimizer.zero_grad() # Reset loss gradients from previous epoch + loss = calc_loss_batch(input_batch, target_batch, model, device) + loss.backward() # Calculate loss gradients + optimizer.step() # Update model weights using loss gradients + examples_seen += input_batch.shape[0] # New: track examples instead of tokens + global_step += 1 + + # Optional evaluation step + if global_step % eval_freq == 0: + train_loss, val_loss = evaluate_model( + model, train_loader, val_loader, device, eval_iter) + train_losses.append(train_loss) + val_losses.append(val_loss) + print(f"Ep {epoch+1} (Step {global_step:06d}): " + f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}") + + # Calculate accuracy after each epoch + train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter) + val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter) + print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="") 
+ print(f"Validation accuracy: {val_accuracy*100:.2f}%") + train_accs.append(train_accuracy) + val_accs.append(val_accuracy) + + return train_losses, val_losses, train_accs, val_accs, examples_seen + + +def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"): + fig, ax1 = plt.subplots(figsize=(5, 3)) + + # Plot training and validation loss against epochs + ax1.plot(epochs_seen, train_values, label=f"Training {label}") + ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}") + ax1.set_xlabel("Epochs") + ax1.set_ylabel(label.capitalize()) + ax1.legend() + + # Create a second x-axis for tokens seen + ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis + ax2.plot(examples_seen, train_values, alpha=0) # Invisible plot for aligning ticks + ax2.set_xlabel("Examples seen") + + fig.tight_layout() # Adjust layout to make room + plt.savefig(f"{label}-plot.pdf") + plt.show() + + +if __name__ == "__main__": + + ######################################## + # Download and prepare dataset + ######################################## + + url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip" + zip_path = "sms_spam_collection.zip" + extracted_path = "sms_spam_collection" + data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv" + + download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path) + df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"]) + balanced_df = create_balanced_dataset(df) + balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1}) + + train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1) + train_df.to_csv("train.csv", index=None) + validation_df.to_csv("validation.csv", index=None) + test_df.to_csv("test.csv", index=None) + + ######################################## + # Create data loaders + ######################################## + tokenizer = tiktoken.get_encoding("gpt2") + + train_dataset = SpamDataset( + csv_file="train.csv", + max_length=None, + tokenizer=tokenizer + ) + + val_dataset = SpamDataset( + csv_file="validation.csv", + max_length=train_dataset.max_length, + tokenizer=tokenizer + ) + + test_dataset = SpamDataset( + csv_file="test.csv", + max_length=train_dataset.max_length, + tokenizer=tokenizer + ) + + num_workers = 0 + batch_size = 8 + + torch.manual_seed(123) + + train_loader = DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + drop_last=True, + ) + + val_loader = DataLoader( + dataset=val_dataset, + batch_size=batch_size, + num_workers=num_workers, + drop_last=False, + ) + + test_loader = DataLoader( + dataset=test_dataset, + batch_size=batch_size, + num_workers=num_workers, + drop_last=False, + ) + + ######################################## + # Load pretrained model + ######################################## + + CHOOSE_MODEL = "gpt2-small (124M)" + INPUT_PROMPT = "Every effort moves" + + BASE_CONFIG = { + "vocab_size": 50257, # Vocabulary size + "context_length": 1024, # Context length + "drop_rate": 0.0, # Dropout rate + "qkv_bias": True # Query-key-value bias + } + + model_configs = { + "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12}, + "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16}, + "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20}, + "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25}, + } + + BASE_CONFIG.update(model_configs[CHOOSE_MODEL]) + + model_size = 
CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
+    settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
+
+    model = GPTModel(BASE_CONFIG)
+    load_weights_into_gpt(model, params)
+    model.eval()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    ########################################
+    # Modify pretrained model
+    ########################################
+
+    for param in model.parameters():
+        param.requires_grad = False
+
+    torch.manual_seed(123)
+
+    num_classes = 2
+    model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
+
+    for param in model.trf_blocks[-1].parameters():
+        param.requires_grad = True
+
+    for param in model.final_norm.parameters():
+        param.requires_grad = True
+
+    ########################################
+    # Finetune modified model
+    ########################################
+
+    start_time = time.time()
+    torch.manual_seed(123)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
+
+    num_epochs = 5
+    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
+        model, train_loader, val_loader, optimizer, device,
+        num_epochs=num_epochs, eval_freq=50, eval_iter=5,
+        tokenizer=tokenizer
+    )
+
+    end_time = time.time()
+    execution_time_minutes = (end_time - start_time) / 60
+    print(f"Training completed in {execution_time_minutes:.2f} minutes.")
+
+    ########################################
+    # Plot results
+    ########################################
+
+    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
+    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
+
+    plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)
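
Follow-up note (not part of the patch): once gpt-class-finetune.py has finished training, the finetuned model can be used to label new messages. The sketch below is a minimal example of how that might look; the helper name `classify_text` is hypothetical, and it assumes the `model`, `tokenizer`, `device`, and `train_dataset` objects created by the script are still in scope.

```python
import torch

def classify_text(text, model, tokenizer, device, max_length, pad_token_id=50256):
    # Hypothetical helper (not defined in the patch above): classify a single message
    model.eval()
    # Tokenize and truncate to the length used during finetuning
    token_ids = tokenizer.encode(text)[:max_length]
    # Pad with the end-of-text token id, matching the SpamDataset preprocessing
    token_ids += [pad_token_id] * (max_length - len(token_ids))
    input_tensor = torch.tensor(token_ids, device=device).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()
    return "spam" if predicted_label == 1 else "ham"

# Example usage after running the script:
# print(classify_text("You won a free prize! Reply WIN to claim.",
#                     model, tokenizer, device, max_length=train_dataset.max_length))
```

Using `train_dataset.max_length` for padding and truncation keeps the inference-time input format identical to the one used during finetuning, which is why the last-token logits can be read off directly.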