diff --git a/ch07/01_main-chapter-code/README.md b/ch07/01_main-chapter-code/README.md
index c84be1c..f3f71d3 100644
--- a/ch07/01_main-chapter-code/README.md
+++ b/ch07/01_main-chapter-code/README.md
@@ -8,4 +8,50 @@
 
 ### Optional Code
 
-- [load-finetuned-model.ipynb](load-finetuned-model.ipynb) is a standalone Jupyter notebook to load the instruction finetuned model we created in this chapter
\ No newline at end of file
+- [load-finetuned-model.ipynb](load-finetuned-model.ipynb) is a standalone Jupyter notebook to load the instruction finetuned model we created in this chapter
+
+- [gpt-instruction-finetuning.py](gpt-instruction-finetuning.py) is a standalone Python script to instruction finetune the model as described in the main chapter
+
+Usage:
+
+```bash
+python gpt-instruction-finetuning.py
+```
+
+```
+matplotlib version: 3.9.0
+tiktoken version: 0.7.0
+torch version: 2.3.1
+tqdm version: 4.66.4
+tensorflow version: 2.16.1
+--------------------------------------------------
+Training set length: 935
+Validation set length: 55
+Test set length: 110
+--------------------------------------------------
+Device: cpu
+File already exists and is up-to-date: gpt2/355M/checkpoint
+File already exists and is up-to-date: gpt2/355M/encoder.json
+File already exists and is up-to-date: gpt2/355M/hparams.json
+File already exists and is up-to-date: gpt2/355M/model.ckpt.data-00000-of-00001
+File already exists and is up-to-date: gpt2/355M/model.ckpt.index
+File already exists and is up-to-date: gpt2/355M/model.ckpt.meta
+File already exists and is up-to-date: gpt2/355M/vocab.bpe
+Loaded model: gpt2-medium (355M)
+--------------------------------------------------
+Initial losses
+   Training loss: 3.839039182662964
+   Validation loss: 3.7619192123413088
+Ep 1 (Step 000000): Train loss 2.611, Val loss 2.668
+Ep 1 (Step 000005): Train loss 1.161, Val loss 1.131
+Ep 1 (Step 000010): Train loss 0.939, Val loss 0.973
+...
+Training completed in 15.66 minutes.
+Plot saved as loss-plot-standalone.pdf
+--------------------------------------------------
+Generating responses
+100%|██████████████████████████████████████████████████████████████████████████| 110/110 [06:57<00:00, 3.80s/it]
+Responses saved as instruction-data-with-response-standalone.json
+Model saved as gpt2-medium355M-sft-standalone.pth
+```
+
diff --git a/ch07/01_main-chapter-code/ch07.ipynb b/ch07/01_main-chapter-code/ch07.ipynb
index a36951f..8f85656 100644
--- a/ch07/01_main-chapter-code/ch07.ipynb
+++ b/ch07/01_main-chapter-code/ch07.ipynb
@@ -426,7 +426,7 @@
    "outputs": [],
    "source": [
     "train_portion = int(len(data) * 0.85)  # 85% for training\n",
-    "test_portion = int(len(data) * 0.1)   # 10% for testing\n",
+    "test_portion = int(len(data) * 0.1)    # 10% for testing\n",
     "val_portion = len(data) - train_portion - test_portion  # Remaining 5% for validation\n",
     "\n",
     "train_data = data[:train_portion]\n",
@@ -1166,7 +1166,8 @@
     "    batch_size=batch_size,\n",
     "    collate_fn=customized_collate_fn,\n",
     "    shuffle=True,\n",
-    "    drop_last=True\n",
+    "    drop_last=True,\n",
+    "    num_workers=num_workers\n",
     ")"
    ]
   },
@@ -1185,7 +1186,8 @@
     "    batch_size=batch_size,\n",
     "    collate_fn=customized_collate_fn,\n",
     "    shuffle=False,\n",
-    "    drop_last=False\n",
+    "    drop_last=False,\n",
+    "    num_workers=num_workers\n",
     ")\n",
     "\n",
     "test_dataset = InstructionDataset(test_data, tokenizer)\n",
@@ -1194,7 +1196,8 @@
     "    batch_size=batch_size,\n",
     "    collate_fn=customized_collate_fn,\n",
     "    shuffle=False,\n",
-    "    drop_last=False\n",
+    "    drop_last=False,\n",
+    "    num_workers=num_workers\n",
     ")"
    ]
   },
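The functional change to the notebook above is that `num_workers` is now passed through to each `DataLoader` (the first hunk only realigns a comment). For context, here is a minimal, self-contained sketch of what the argument does; the dataset below is a stand-in, not the chapter's `InstructionDataset`: `num_workers=0` loads batches in the main process, while a positive value uses that many background worker processes for data loading.

```python
from torch.utils.data import DataLoader

# Minimal stand-in dataset (just integers) to illustrate the argument; in the
# chapter the dataset is an InstructionDataset with a custom collate_fn.
dummy_data = list(range(8))

loader = DataLoader(
    dummy_data,
    batch_size=2,
    shuffle=False,
    num_workers=0,  # 0 = load batches in the main process (the setting used in the chapter)
)

for batch in loader:
    print(batch)  # tensor([0, 1]), tensor([2, 3]), ...
```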
diff --git a/ch07/01_main-chapter-code/gpt-instruction-finetuning.py b/ch07/01_main-chapter-code/gpt-instruction-finetuning.py
new file mode 100644
index 0000000..686ca1d
--- /dev/null
+++ b/ch07/01_main-chapter-code/gpt-instruction-finetuning.py
@@ -0,0 +1,307 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# A minimal instruction finetuning file based on the code in chapter 7

from functools import partial
from importlib.metadata import version
import json
import os
import re
import time
import urllib.request

import matplotlib.pyplot as plt
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Import from local files in this folder
from gpt_download import download_and_load_gpt2
from previous_chapters import (
    calc_loss_loader,
    generate,
    GPTModel,
    load_weights_into_gpt,
    text_to_token_ids,
    train_model_simple,
    token_ids_to_text
)


class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)


def custom_collate_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to max_length
        padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # New: Replace all but the first padding tokens in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # New: Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert list of inputs and targets to tensors and transfer to target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor
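# Illustrative example: for a toy batch of two pre-tokenized sequences,
# custom_collate_fn appends the <|endoftext|> token (50256), pads both sequences
# to the same length, shifts the targets one position to the right, and replaces
# every padding token after the first one in the targets with -100 so that those
# positions are ignored by PyTorch's cross entropy loss:
#
#     inputs, targets = custom_collate_fn([[1, 2, 3], [4, 5]])
#     # inputs  (token ids):  [[1, 2, 3], [4, 5, 50256]]
#     # targets (shifted +1): [[2, 3, 50256], [5, 50256, -100]]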
def download_and_load_file(file_path, url):

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()

    with open(file_path, "r") as file:
        data = json.load(file)

    return data


def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plot_name = "loss-plot-standalone.pdf"
    print(f"Plot saved as {plot_name}")
    plt.savefig(plot_name)
    # plt.show()


def main():
    #######################################
    # Print package versions
    #######################################
    print()
    pkgs = [
        "matplotlib",  # Plotting library
        "tiktoken",    # Tokenizer
        "torch",       # Deep learning library
        "tqdm",        # Progress bar
        "tensorflow",  # For OpenAI's pretrained weights
    ]
    for p in pkgs:
        print(f"{p} version: {version(p)}")
    print(50*"-")

    #######################################
    # Download and prepare dataset
    #######################################
    file_path = "instruction-data.json"
    url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json"
    data = download_and_load_file(file_path, url)

    train_portion = int(len(data) * 0.85)  # 85% for training
    test_portion = int(len(data) * 0.1)    # 10% for testing

    train_data = data[:train_portion]
    test_data = data[train_portion:train_portion + test_portion]
    val_data = data[train_portion + test_portion:]

    print("Training set length:", len(train_data))
    print("Validation set length:", len(val_data))
    print("Test set length:", len(test_data))
    print(50*"-")

    tokenizer = tiktoken.get_encoding("gpt2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_dataset = InstructionDataset(train_data, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers
    )

    val_dataset = InstructionDataset(val_data, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    #######################################
    # Load pretrained model
    #######################################
    BASE_CONFIG = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    CHOOSE_MODEL = "gpt2-medium (355M)"
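    # With "gpt2-medium (355M)" selected, the update below copies emb_dim=1024,
    # n_layers=24, and n_heads=16 from model_configs into BASE_CONFIG, and
    # model_size is parsed as "355M" to pick the weight files to download.
    # A smaller variant such as "gpt2-small (124M)" also works and reduces
    # download size and finetuning time.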
    BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

    model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
    settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

    model = GPTModel(BASE_CONFIG)
    load_weights_into_gpt(model, params)
    model.eval()
    model.to(device)

    print("Loaded model:", CHOOSE_MODEL)
    print(50*"-")

    #######################################
    # Finetuning the model
    #######################################
    print("Initial losses")
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)

    print("   Training loss:", train_loss)
    print("   Validation loss:", val_loss)

    start_time = time.time()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)
    num_epochs = 2

    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=5, eval_iter=5,
        start_context=format_input(val_data[0]), tokenizer=tokenizer
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
    print(50*"-")

    #######################################
    # Generating responses and saving results
    #######################################
    print("Generating responses")
    for i, entry in tqdm(enumerate(test_data), total=len(test_data)):

        input_text = format_input(entry)

        token_ids = generate(
            model=model,
            idx=text_to_token_ids(input_text, tokenizer).to(device),
            max_new_tokens=256,
            context_size=BASE_CONFIG["context_length"],
            eos_id=50256
        )
        generated_text = token_ids_to_text(token_ids, tokenizer)
        response_text = generated_text[len(input_text):].replace("### Response:", "").strip()

        test_data[i]["model_response"] = response_text

    test_data_path = "instruction-data-with-response-standalone.json"
    with open(test_data_path, "w") as file:
        json.dump(test_data, file, indent=4)  # "indent" for pretty-printing
    print(f"Responses saved as {test_data_path}")

    file_name = f"{re.sub(r'[ ()]', '', CHOOSE_MODEL)}-sft-standalone.pth"
    torch.save(model.state_dict(), file_name)
    print(f"Model saved as {file_name}")


if __name__ == "__main__":
    main()
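For reference, the saved weights can be reloaded for inference without rerunning the finetuning, similar to what load-finetuned-model.ipynb does. The following is a minimal sketch, assuming the run above produced gpt2-medium355M-sft-standalone.pth and that previous_chapters.py provides the same helper functions the script imports; the example prompt is arbitrary:

```python
import tiktoken
import torch

from previous_chapters import (
    generate,
    GPTModel,
    text_to_token_ids,
    token_ids_to_text,
)

# Same configuration that was used for finetuning gpt2-medium (355M)
BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.0,
    "qkv_bias": True,
    "emb_dim": 1024,
    "n_layers": 24,
    "n_heads": 16,
}

model = GPTModel(BASE_CONFIG)
model.load_state_dict(torch.load("gpt2-medium355M-sft-standalone.pth", map_location="cpu"))
model.eval()

tokenizer = tiktoken.get_encoding("gpt2")

# Arbitrary example prompt in the same Alpaca-style format used during finetuning
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
    "\n\n### Instruction:\nConvert the following sentence to passive voice."
    "\n\n### Input:\nThe chef cooks the meal every day."
)

token_ids = generate(
    model=model,
    idx=text_to_token_ids(prompt, tokenizer),
    max_new_tokens=256,
    context_size=BASE_CONFIG["context_length"],
    eos_id=50256,
)
generated_text = token_ids_to_text(token_ids, tokenizer)
response = generated_text[len(prompt):].replace("### Response:", "").strip()
print(response)
```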