mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-11 18:22:01 +00:00
Merge pull request #69 from rasbt/pretraining-on-proj-gutenberg
Pretraining on Project Gutenberg
This commit is contained in:
commit 0b66c55950
File diff suppressed because one or more lines are too long
121  ch05/03_bonus_pretraining_on_gutenberg/README.md  Normal file
@ -0,0 +1,121 @@
# Pretraining GPT on the Project Gutenberg Dataset

This directory contains code for training a small GPT model on the free books provided by Project Gutenberg.

As the Project Gutenberg website states, "the vast majority of Project Gutenberg eBooks are in the public domain in the US."

Please read the [Project Gutenberg Permissions, Licensing and other Common Requests](https://www.gutenberg.org/policy/permission.html) page for more information about using the resources provided by Project Gutenberg.

## How to use this code

### 1) Download the dataset

As of this writing, this will require approximately 50 GB of disk space, but it may be more depending on how much Project Gutenberg has grown since then.

Follow these steps to download the dataset:

1. `git clone https://github.com/pgcorpus/gutenberg.git`

2. `cd gutenberg`

3. `pip install -r requirements.txt`

4. `python get_data.py`

5. `cd ..`
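
Equivalently, the steps above can be run as a single shell snippet (this is only a convenience and assumes a Unix-like shell; the commands mirror the list above):

```bash
# Clone the pgcorpus/gutenberg scraper and download the raw book files
git clone https://github.com/pgcorpus/gutenberg.git
cd gutenberg
pip install -r requirements.txt
python get_data.py
cd ..
```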

### 2) Prepare the dataset

Next, run the `prepare_dataset.py` script, which concatenates the (as of this writing, 60,173) text files into fewer, larger files so that they can be transferred and accessed more efficiently:

```bash
prepare_dataset.py \
  --data_dir "gutenberg/data" \
  --max_size_mb 500 \
  --output_dir "gutenberg_preprocessed"
```

> [!TIP]
> Note that the produced files are stored in plaintext format and are not pre-tokenized, for simplicity. However, if you plan to use the dataset frequently or train for multiple epochs, you may want to update the code to store the dataset in a pre-tokenized form to save computation time. See the *Design Decisions and Improvements* section at the bottom of this page for more information, and the sketch below for one possible approach.
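
For illustration, a minimal pre-tokenization pass (not part of this repository; the directory and file names are placeholders) could look like this:

```python
# Sketch: convert each combined plaintext file into a .pt file of GPT-2 token IDs
# so the pretraining script would not need to re-tokenize the text on every run.
from pathlib import Path

import tiktoken
import torch

tokenizer = tiktoken.get_encoding("gpt2")

input_dir = Path("gutenberg_preprocessed")       # output of prepare_dataset.py
output_dir = Path("gutenberg_pretokenized")      # placeholder name
output_dir.mkdir(parents=True, exist_ok=True)

for txt_file in sorted(input_dir.glob("*.txt")):
    text = txt_file.read_text(encoding="utf-8")
    # Keep the <|endoftext|> document separator as a single special token
    token_ids = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    out_path = output_dir / f"{txt_file.stem}.pt"
    torch.save(torch.tensor(token_ids, dtype=torch.long), out_path)
    print(f"Saved {out_path} ({len(token_ids)} tokens)")
```

The data-loading code in `pretraining_simple.py` would then need to be adjusted to load the saved token ID tensors instead of reading and tokenizing raw text.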

> [!TIP]
> You can choose smaller file sizes, for example, 50 MB. This will result in more files but might be useful for quicker pretraining runs on a small number of files for testing purposes.
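
For example, a run that produces smaller files could look like this (the output directory name here is just an illustration):

```bash
prepare_dataset.py \
  --data_dir "gutenberg/data" \
  --max_size_mb 50 \
  --output_dir "gutenberg_preprocessed_small"
```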

### 3) Run the pretraining script

You can run the pretraining script as follows. Note that the additional command-line arguments are shown with their default values for illustration purposes:

```bash
pretraining_simple.py \
  --data_dir "gutenberg_preprocessed" \
  --n_epochs 1 \
  --batch_size 4 \
  --output_dir model_checkpoints
```

The output will be formatted in the following way:

```
Total files: 3
Tokenizing file 1 of 3: data_small/combined_1.txt
Training ...
Ep 1 (Step 0): Train loss 9.694, Val loss 9.724
Ep 1 (Step 100): Train loss 6.672, Val loss 6.683
Ep 1 (Step 200): Train loss 6.543, Val loss 6.434
Ep 1 (Step 300): Train loss 5.772, Val loss 6.313
Ep 1 (Step 400): Train loss 5.547, Val loss 6.249
Ep 1 (Step 500): Train loss 6.182, Val loss 6.155
Ep 1 (Step 600): Train loss 5.742, Val loss 6.122
Ep 1 (Step 700): Train loss 6.309, Val loss 5.984
Ep 1 (Step 800): Train loss 5.435, Val loss 5.975
Ep 1 (Step 900): Train loss 5.582, Val loss 5.935
...
Ep 1 (Step 31900): Train loss 3.664, Val loss 3.946
Ep 1 (Step 32000): Train loss 3.493, Val loss 3.939
Ep 1 (Step 32100): Train loss 3.940, Val loss 3.961
Saved model_checkpoints/model_pg_32188.pth
Book processed 3h 46m 55s
Total time elapsed 3h 46m 55s
ETA for remaining books: 7h 33m 50s
Tokenizing file 2 of 3: data_small/combined_2.txt
Training ...
Ep 1 (Step 32200): Train loss 2.982, Val loss 4.094
Ep 1 (Step 32300): Train loss 3.920, Val loss 4.097
...
```

> [!TIP]
> In practice, if you are using macOS or Linux, I recommend using the `tee` command to save the log outputs to a `log.txt` file in addition to printing them on the terminal:

```bash
python -u pretraining_simple.py | tee log.txt
```

> [!WARNING]
> Note that training on one of the ~500 MB text files in the `gutenberg_preprocessed` folder will take approximately 4 hours on a V100 GPU.
> The folder contains 47 files, so completing all of them will take approximately 200 hours (more than 1 week). You may want to run it on a smaller number of files; one way to do this is shown below.
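
For example, a quick test run on a single preprocessed file could look like the following (the `gutenberg_subset` and `model_checkpoints_subset` names are just placeholders):

```bash
mkdir gutenberg_subset
cp gutenberg_preprocessed/combined_1.txt gutenberg_subset/
python pretraining_simple.py \
  --data_dir "gutenberg_subset" \
  --n_epochs 1 \
  --batch_size 4 \
  --output_dir model_checkpoints_subset
```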

## Design Decisions and Improvements

Note that this code focuses on keeping things simple and minimal for educational purposes. The code could be extended in the following ways to improve modeling performance and training efficiency:

1. Modify the `prepare_dataset.py` script to strip the Gutenberg boilerplate text from each book file.
2. Update the data preparation and loading utilities to pre-tokenize the dataset and save it in a tokenized form so that it doesn't have to be re-tokenized each time the pretraining script is called.
3. Update the `train_model_simple` script by adding the features introduced in [Appendix D: Adding Bells and Whistles to the Training Loop](../../appendix-D/01_main-chapter-code/appendix-D.ipynb), namely, cosine decay, linear warmup, and gradient clipping.
4. Update the pretraining script to save the optimizer state (see section *5.4 Loading and saving weights in PyTorch* in chapter 5; [ch05.ipynb](../../ch05/01_main-chapter-code/ch05.ipynb)) and add the option to load an existing model and optimizer checkpoint to continue training if a training run was interrupted; a sketch of this is shown after this list.
5. Add a more advanced logger (for example, Weights and Biases) to view the loss and validation curves live.
6. Add distributed data parallelism (DDP) and train the model on multiple GPUs (see section *A.9.3 Training with multiple GPUs* in appendix A; [DDP-script.py](../../appendix-A/03_main-chapter-code/DDP-script.py)).
7. Swap the from-scratch `MultiHeadAttention` class in the `previous_chapters.py` script with the efficient `MHAPyTorchScaledDotProduct` class implemented in the [Efficient Multi-Head Attention Implementations](../../ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb) bonus section, which uses Flash Attention via PyTorch's `nn.functional.scaled_dot_product_attention` function.
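
As a sketch for point 4 above, saving and restoring both the model and the optimizer state with plain PyTorch could look as follows. This is not part of the current scripts; it assumes `model`, `optimizer`, and `device` are defined as in `pretraining_simple.py`, and the checkpoint file name is a placeholder:

```python
import torch

# Saving: store the model and optimizer state together in one checkpoint file
checkpoint_path = "model_checkpoints/model_and_optimizer.pth"  # placeholder name
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, checkpoint_path)

# Resuming: restore both states before continuing training
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train()
```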
66  ch05/03_bonus_pretraining_on_gutenberg/prepare_dataset.py  Normal file
@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
"""
Script that processes the Project Gutenberg files into fewer larger files.
"""

import argparse
import os


def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftext|>", fallback_encoding="latin1"):
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    current_content = []
    current_size = 0
    file_counter = 1

    for file_path in file_paths:
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
        except UnicodeDecodeError:
            # Attempt to read the file with a fallback encoding
            print(f"Warning: UnicodeDecodeError encountered. Trying fallback encoding for {file_path}")
            with open(file_path, "r", encoding=fallback_encoding) as file:
                content = file.read()

        estimated_size = len(content.encode("utf-8"))

        if current_size + estimated_size > max_size_mb * 1024 * 1024:
            target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt")
            with open(target_file_path, "w", encoding="utf-8") as target_file:
                target_file.write(separator.join(current_content))
            file_counter += 1
            current_content = [content]
            current_size = estimated_size
        else:
            current_content.append(content)
            current_size += estimated_size

    if current_content:
        target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt")
        with open(target_file_path, "w", encoding="utf-8") as target_file:
            target_file.write(separator.join(current_content))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="GPT Model Training Configuration")

    parser.add_argument("--data_dir", type=str, default="gutenberg/data",
                        help="Directory containing the downloaded raw training data")
    parser.add_argument("--max_size_mb", type=int, default=500,
                        help="The maximum file size for each concatenated file in megabytes")
    parser.add_argument("--output_dir", type=str, default="gutenberg_preprocessed",
                        help="Directory where the preprocessed data will be saved")

    args = parser.parse_args()

    all_files = [os.path.join(path, name) for path, subdirs, files in os.walk(args.data_dir)
                 for name in files if name.endswith((".txt", ".txt.utf8")) and "raw" not in path]

    print(f"{len(all_files)} files to process.")

    combine_files(all_files, args.output_dir, max_size_mb=args.max_size_mb)
212  ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py  Normal file
@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
"""
Script for pretraining a small GPT-2 124M parameter model
on books from Project Gutenberg.

Before running this script, make sure you downloaded and
processed the dataset as described in the README.md.
"""

import argparse
import os
from pathlib import Path
import time
import torch
from previous_chapters import (
    create_dataloader_v1,
    GPTModel,
    generate_and_print_sample,
    calc_loss_batch,
    evaluate_model,
    plot_losses
)


def read_text_file(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        text_data = file.read()
    return text_data


def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride):
    split_idx = int(train_ratio * len(text_data))
    train_loader = create_dataloader_v1(
        text_data[:split_idx],
        batch_size=batch_size,
        max_length=max_length,
        stride=stride,
        drop_last=True,
        shuffle=True
    )
    val_loader = create_dataloader_v1(
        text_data[split_idx:],
        batch_size=batch_size,
        max_length=max_length,
        stride=stride,
        drop_last=False,
        shuffle=False
    )
    return train_loader, val_loader


def convert_time(seconds):
    hours, rem = divmod(seconds, 3600)
    minutes, seconds = divmod(rem, 60)
    return int(hours), int(minutes), int(seconds)


def print_eta(start_time, book_start_time, index, total_files):
    book_end_time = time.time()  # End time of processing this book
    elapsed_time = book_end_time - book_start_time
    total_elapsed_time = book_end_time - start_time
    books_remaining = total_files - index
    average_time_per_book = total_elapsed_time / index
    eta = average_time_per_book * books_remaining

    book_h, book_m, book_s = convert_time(elapsed_time)
    total_h, total_m, total_s = convert_time(total_elapsed_time)
    eta_h, eta_m, eta_s = convert_time(eta)

    print(f"Book processed {book_h}h {book_m}m {book_s}s"
          f"\nTotal time elapsed {total_h}h {total_m}m {total_s}s"
          f"\nETA for remaining books: {eta_h}h {eta_m}m {eta_s}s")


def train_model_simple(model, optimizer, device, n_epochs,
                       eval_freq, eval_iter, print_sample_iter, start_context,
                       output_dir, save_ckpt_freq,
                       batch_size=1024, train_ratio=0.90):
    # Note: relies on `all_files`, `total_files`, and `GPT_CONFIG_124M`
    # being defined in the __main__ section below

    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen = 0
    global_step = -1
    start_time = time.time()

    try:
        for epoch in range(n_epochs):

            # Iterate over the books in the training corpus
            for index, file_path in enumerate(all_files, 1):
                book_start_time = time.time()
                text_data = read_text_file(file_path) + " <|endoftext|> "
                print(f"Tokenizing file {index} of {total_files}: {file_path}")

                # Initialize new data loaders for each book
                train_loader, val_loader = create_dataloaders(
                    text_data,
                    train_ratio=train_ratio,
                    batch_size=batch_size,
                    max_length=GPT_CONFIG_124M["ctx_len"],
                    stride=GPT_CONFIG_124M["ctx_len"]
                )
                print("Training ...")
                model.train()
                for input_batch, target_batch in train_loader:
                    optimizer.zero_grad()
                    loss = calc_loss_batch(input_batch, target_batch, model, device)
                    loss.backward()
                    optimizer.step()
                    tokens_seen += input_batch.numel()
                    global_step += 1

                    # Optional evaluation step
                    if global_step % eval_freq == 0:
                        train_loss, val_loss = evaluate_model(
                            model, train_loader, val_loader, device, eval_iter)
                        train_losses.append(train_loss)
                        val_losses.append(val_loss)
                        track_tokens_seen.append(tokens_seen)
                        print(f"Ep {epoch+1} (Step {global_step}): "
                              f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

                    # Generate text passage
                    if global_step % print_sample_iter == 0:
                        generate_and_print_sample(
                            model, train_loader.dataset.tokenizer, device, start_context
                        )

                # Save a model checkpoint
                if global_step % save_ckpt_freq:
                    file_name = output_dir / f"model_pg_{global_step}.pth"
                    torch.save(model.state_dict(), file_name)
                    print(f"Saved {file_name}")

                print_eta(start_time, book_start_time, index, total_files)

    except KeyboardInterrupt:
        file_name = output_dir / f"model_pg_{global_step}_interrupted.pth"
        torch.save(model.state_dict(), file_name)
        print(f"Saved {file_name}")

    return train_losses, val_losses, track_tokens_seen


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='GPT Model Training Configuration')

    parser.add_argument('--data_dir', type=str, default='gutenberg/data',
                        help='Directory containing the training data')
    parser.add_argument('--output_dir', type=str, default='model_checkpoints',
                        help='Directory where the model checkpoints will be saved')
    parser.add_argument('--n_epochs', type=int, default=1,
                        help='Number of epochs to train the model')
    parser.add_argument('--print_sample_iter', type=int, default=500,
                        help='Iterations between printing sample outputs')
    parser.add_argument('--eval_freq', type=int, default=100,
                        help='Frequency of evaluations during training')
    parser.add_argument('--save_ckpt_freq', type=int, default=100_000,
                        help='Frequency of saving model checkpoints during training')
    parser.add_argument('--lr', type=float, default=5e-4,
                        help='Learning rate for the optimizer')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size for training')

    args = parser.parse_args()

    GPT_CONFIG_124M = {
        "vocab_size": 50257,  # Vocabulary size
        "ctx_len": 1024,      # Context length
        "emb_dim": 768,       # Embedding dimension
        "n_heads": 12,        # Number of attention heads
        "n_layers": 12,       # Number of layers
        "drop_rate": 0.1,     # Dropout rate
        "qkv_bias": False     # Query-key-value bias
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(123)
    model = GPTModel(GPT_CONFIG_124M)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1)

    data_dir = args.data_dir
    all_files = [os.path.join(path, name) for path, subdirs, files
                 in os.walk(data_dir) for name in files if name.endswith(".txt")]
    total_files = len(all_files)

    if total_files == 0:
        print("No training text files found. Make sure you "
              "selected the correct input directory")
        quit()
    print("Total files:", total_files)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_losses, val_losses, tokens_seen = train_model_simple(
        model, optimizer, device,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        eval_freq=args.eval_freq,
        eval_iter=1,
        print_sample_iter=args.print_sample_iter,
        output_dir=output_dir,
        save_ckpt_freq=args.save_ckpt_freq,
        start_context="Every effort moves you",
    )

    epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses))
    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)

    torch.save(model.state_dict(), output_dir / "model_pg_final.pth")
    print(f"Maximum GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
313  ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py  Normal file
@ -0,0 +1,313 @@
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-4.
# This file can be run as a standalone script.

import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


#####################################
# Chapter 2
#####################################

class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.target_ids = []

        token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})

        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True):
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

    return dataloader


#####################################
# Chapter 3
#####################################

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec


#####################################
# Chapter 4
#####################################

class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
            nn.Dropout(cfg["drop_rate"])
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            block_size=cfg["ctx_len"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits


def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the idx of the vocab entry with the highest logits value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx


#####################################
# Chapter 5
#####################################

def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    logits = model(input_batch)
    logits = logits.view(-1, logits.size(-1))
    loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss, batches_seen = 0., 0.
    if num_batches is None:
        num_batches = len(data_loader)
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
            batches_seen += 1
        else:
            break
    return total_loss / batches_seen


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(model=model, idx=encoded,
                                         max_new_tokens=50, context_size=context_size)
        decoded_text = token_ids_to_text(token_ids, tokenizer)
        print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
    fig, ax1 = plt.subplots()

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(output_dir / "losses.pdf")


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())