Merge pull request #69 from rasbt/pretraining-on-proj-gutenberg

Pretraining on Project Gutenberg
Sebastian Raschka 2024-03-13 08:38:33 -05:00 committed by GitHub
commit 0b66c55950
8 changed files with 1710 additions and 998 deletions

File diff suppressed because one or more lines are too long


@@ -0,0 +1,121 @@
# Pretraining GPT on the Project Gutenberg Dataset
This directory contains code for training a small GPT model on the free books provided by Project Gutenberg.
As the Project Gutenberg website states, "the vast majority of Project Gutenberg eBooks are in the public domain in the US."
Please read the [Project Gutenberg Permissions, Licensing and other Common Requests](https://www.gutenberg.org/policy/permission.html) page for more information about using the resources provided by Project Gutenberg.
 
## How to use this code
 
### 1) Download the dataset
As of this writing, the dataset requires approximately 50 GB of disk space, but it may be more depending on how much Project Gutenberg has grown since then.
Follow these steps to download the dataset:
1. `git clone https://github.com/pgcorpus/gutenberg.git`
2. `cd gutenberg`
3. `pip install -r requirements.txt`
4. `python get_data.py`
5. `cd ..`
 
### 2) Prepare the dataset
Next, run the `prepare_dataset.py` script, which concatenates the (as of this writing, 60,173) text files into fewer larger files so that they can be more efficiently transferred and accessed:
```bash
python prepare_dataset.py \
--data_dir "gutenberg/data" \
--max_size_mb 500 \
--output_dir "gutenberg_preprocessed"
```
> [!TIP]
> Note that the produced files are stored in plaintext format and are not pre-tokenized for simplicity. However, you may want to update the code to store the dataset in a pre-tokenized form to save computation time if you plan to use the dataset more often or train for multiple epochs (a minimal sketch is shown below). See the *Design Decisions and Improvements* section at the bottom of this page for more information.

> [!TIP]
> You can choose smaller file sizes, for example, 50 MB. This will result in more files but might be useful for quicker pretraining runs on a small number of files for testing purposes.
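
The following is a minimal sketch of what such a pre-tokenization step could look like, assuming tiktoken's GPT-2 encoding and PyTorch's `torch.save`/`torch.load` for storage (the file paths are placeholders):

```python
import tiktoken
import torch

tokenizer = tiktoken.get_encoding("gpt2")

# Read one of the combined plaintext files produced by prepare_dataset.py
with open("gutenberg_preprocessed/combined_1.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Tokenize once and store the token IDs so that later runs can skip tokenization
token_ids = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
torch.save(torch.tensor(token_ids), "gutenberg_preprocessed/combined_1.pt")

# In the training code, the saved IDs could then be loaded directly
token_ids = torch.load("gutenberg_preprocessed/combined_1.pt")
```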
 
### 3) Run the pretraining script
You can run the pretraining script as follows. Note that the additional command-line arguments are shown with their default values for illustration purposes:
```bash
python pretraining_simple.py \
--data_dir "gutenberg_preprocessed" \
--n_epochs 1 \
--batch_size 4 \
--output_dir model_checkpoints
```
The output will be formatted in the following way:
```
Total files: 3
Tokenizing file 1 of 3: data_small/combined_1.txt
Training ...
Ep 1 (Step 0): Train loss 9.694, Val loss 9.724
Ep 1 (Step 100): Train loss 6.672, Val loss 6.683
Ep 1 (Step 200): Train loss 6.543, Val loss 6.434
Ep 1 (Step 300): Train loss 5.772, Val loss 6.313
Ep 1 (Step 400): Train loss 5.547, Val loss 6.249
Ep 1 (Step 500): Train loss 6.182, Val loss 6.155
Ep 1 (Step 600): Train loss 5.742, Val loss 6.122
Ep 1 (Step 700): Train loss 6.309, Val loss 5.984
Ep 1 (Step 800): Train loss 5.435, Val loss 5.975
Ep 1 (Step 900): Train loss 5.582, Val loss 5.935
...
Ep 1 (Step 31900): Train loss 3.664, Val loss 3.946
Ep 1 (Step 32000): Train loss 3.493, Val loss 3.939
Ep 1 (Step 32100): Train loss 3.940, Val loss 3.961
Saved model_checkpoints/model_pg_32188.pth
Book processed 3h 46m 55s
Total time elapsed 3h 46m 55s
ETA for remaining books: 7h 33m 50s
Tokenizing file 2 of 3: data_small/combined_2.txt
Training ...
Ep 1 (Step 32200): Train loss 2.982, Val loss 4.094
Ep 1 (Step 32300): Train loss 3.920, Val loss 4.097
...
```
 
> [!TIP]
> In practice, if you are using macOS or Linux, I recommend using the `tee` command to save the log outputs to a `log.txt` file in addition to printing them on the terminal:
```bash
python -u pretraining_simple.py | tee log.txt
```
 
> [!WARNING]
> Note that training on 1 of the ~500 MB text files in the `gutenberg_preprocessed` folder will take approximately 4 hours on a V100 GPU.
> The folder contains 47 files, so training on all of them will take approximately 200 hours (more than 1 week). You may want to run it on a smaller number of files.
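
For a quick test run, one option is to copy just one or two of the preprocessed files into a separate folder and point the pretraining script at it via `--data_dir`; a minimal sketch (with placeholder directory names) is shown below:

```python
import shutil
from pathlib import Path

subset_dir = Path("gutenberg_subset")
subset_dir.mkdir(exist_ok=True)

# Copy a single combined file for a short test run
shutil.copy("gutenberg_preprocessed/combined_1.txt", subset_dir)

# Afterwards, run: python pretraining_simple.py --data_dir gutenberg_subset
```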
 
## Design Decisions and Improvements
Note that this code focuses on keeping things simple and minimal for educational purposes. It could be improved in the following ways to boost modeling performance and training efficiency:
1. Modify the `prepare_dataset.py` script to strip the Gutenberg boilerplate text from each book file (see the first sketch below this list).
2. Update the data preparation and loading utilities to pre-tokenize the dataset and save it in tokenized form so that it doesn't have to be re-tokenized each time the pretraining script is run.
3. Update the `train_model_simple` script by adding the features introduced in [Appendix D: Adding Bells and Whistles to the Training Loop](../../appendix-D/01_main-chapter-code/appendix-D.ipynb), namely, cosine decay, linear warmup, and gradient clipping (see the second sketch below this list).
4. Update the pretraining script to save the optimizer state (see section *5.4 Loading and saving weights in PyTorch* in chapter 5; [ch05.ipynb](../../ch05/01_main-chapter-code/ch05.ipynb)) and add the option to load an existing model and optimizer checkpoint to continue training if a run was interrupted (see the third sketch below this list).
5. Add a more advanced logger (for example, Weights & Biases) to view the loss and validation curves live.
6. Add distributed data parallelism (DDP) and train the model on multiple GPUs (see section *A.9.3 Training with multiple GPUs* in appendix A; [DDP-script.py](../../appendix-A/03_main-chapter-code/DDP-script.py)).
7. Swap the from-scratch `MultiHeadAttention` class in the `previous_chapters.py` script with the efficient `MHAPyTorchScaledDotProduct` class implemented in the [Efficient Multi-Head Attention Implementations](../../ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb) bonus section, which uses Flash Attention via PyTorch's `nn.functional.scaled_dot_product_attention` function.
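
Regarding item 1, Project Gutenberg books typically delimit the actual book text with `*** START OF ... ***` and `*** END OF ... ***` marker lines, so a rough (not bulletproof) stripping step could look like the following sketch; the exact marker wording varies between books:

```python
import re

def strip_gutenberg_boilerplate(text):
    # Keep only the text between the *** START OF ... *** and *** END OF ... *** markers
    start = re.search(r"\*\*\* ?START OF.*?\*\*\*", text)
    end = re.search(r"\*\*\* ?END OF.*?\*\*\*", text)
    if start and end and start.end() < end.start():
        return text[start.end():end.start()]
    return text  # fall back to the unmodified text if no markers are found
```

Regarding item 3, Appendix D implements warmup and cosine decay manually; a sketch that approximates the same behavior with PyTorch's built-in schedulers plus gradient clipping (the step counts are placeholder values, and `model`, `optimizer`, `train_loader`, `calc_loss_batch`, and `device` are assumed from `pretraining_simple.py`) could look like this:

```python
import torch

warmup_steps = 1_000    # placeholder value
total_steps = 100_000   # placeholder value

warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, total_iters=warmup_steps)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=total_steps - warmup_steps)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

for input_batch, target_batch in train_loader:
    optimizer.zero_grad()
    loss = calc_loss_batch(input_batch, target_batch, model, device)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
    scheduler.step()  # update the learning rate after each optimizer step
```

Regarding item 4, a minimal sketch of saving and restoring both the model and the optimizer state (again assuming the `model`, `optimizer`, `global_step`, and `device` objects from `pretraining_simple.py`; the checkpoint path is a placeholder) could look as follows:

```python
import torch

# Save the model and optimizer state together so training can be resumed later
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "global_step": global_step,
}, "model_checkpoints/model_and_optimizer.pth")

# Later: restore both and continue training from where it left off
checkpoint = torch.load("model_checkpoints/model_and_optimizer.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
global_step = checkpoint["global_step"]
```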


@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
"""
Script that processes the Project Gutenberg files into fewer larger files.
"""
import argparse
import os
def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftext|>", fallback_encoding="latin1"):
if not os.path.exists(target_dir):
os.makedirs(target_dir)
current_content = []
current_size = 0
file_counter = 1
for file_path in file_paths:
try:
with open(file_path, "r", encoding="utf-8") as file:
content = file.read()
except UnicodeDecodeError:
# Attempt to read the file with a fallback encoding
print(f"Warning: UnicodeDecodeError encountered. Trying fallback encoding for {file_path}")
with open(file_path, "r", encoding=fallback_encoding) as file:
content = file.read()
estimated_size = len(content.encode("utf-8"))
if current_size + estimated_size > max_size_mb * 1024 * 1024:
target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt")
with open(target_file_path, "w", encoding="utf-8") as target_file:
target_file.write(separator.join(current_content))
file_counter += 1
current_content = [content]
current_size = estimated_size
else:
current_content.append(content)
current_size += estimated_size
if current_content:
target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt")
with open(target_file_path, "w", encoding="utf-8") as target_file:
target_file.write(separator.join(current_content))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GPT Model Training Configuration")
parser.add_argument("--data_dir", type=str, default="gutenberg/data",
help="Directory containing the downloaded raw training data")
parser.add_argument("--max_size_mb", type=int, default=500,
help="The maximum file size for each concatenated file in megabytes")
parser.add_argument("--output_dir", type=str, default="gutenberg_preprocessed",
help="Directory where the preprocessed data will be saved")
args = parser.parse_args()
all_files = [os.path.join(path, name) for path, subdirs, files in os.walk(args.data_dir)
for name in files if name.endswith((".txt", ".txt.utf8")) and "raw" not in path]
print(f"{len(all_files)} files to process.")
combine_files(all_files, args.output_dir, max_size_mb=args.max_size_mb)


@@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
"""
Script for pretraining a small GPT-2 124M parameter model
on books from Project Gutenberg.
Before running this script, make sure you downloaded and
processed the dataset as described in the README.md.
"""
import argparse
import os
from pathlib import Path
import time
import torch
from previous_chapters import (
create_dataloader_v1,
GPTModel,
generate_and_print_sample,
calc_loss_batch,
evaluate_model,
plot_losses
)
def read_text_file(file_path):
with open(file_path, "r", encoding="utf-8") as file:
text_data = file.read()
return text_data
def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride):
split_idx = int(train_ratio * len(text_data))
train_loader = create_dataloader_v1(
text_data[:split_idx],
batch_size=batch_size,
max_length=max_length,
stride=stride,
drop_last=True,
shuffle=True
)
val_loader = create_dataloader_v1(
text_data[split_idx:],
batch_size=batch_size,
max_length=max_length,
stride=stride,
drop_last=False,
shuffle=False
)
return train_loader, val_loader
def convert_time(seconds):
hours, rem = divmod(seconds, 3600)
minutes, seconds = divmod(rem, 60)
return int(hours), int(minutes), int(seconds)
def print_eta(start_time, book_start_time, index, total_files):
book_end_time = time.time() # End time of processing this book
elapsed_time = book_end_time - book_start_time
total_elapsed_time = book_end_time - start_time
books_remaining = total_files - index
average_time_per_book = total_elapsed_time / index
eta = average_time_per_book * books_remaining
book_h, book_m, book_s = convert_time(elapsed_time)
total_h, total_m, total_s = convert_time(total_elapsed_time)
eta_h, eta_m, eta_s = convert_time(eta)
print(f"Book processed {book_h}h {book_m}m {book_s}s"
f"\nTotal time elapsed {total_h}h {total_m}m {total_s}s"
f"\nETA for remaining books: {eta_h}h {eta_m}m {eta_s}s")
def train_model_simple(model, optimizer, device, n_epochs,
eval_freq, eval_iter, print_sample_iter, start_context,
output_dir, save_ckpt_freq,
batch_size=1024, train_ratio=0.90):
train_losses, val_losses, track_tokens_seen = [], [], []
tokens_seen = 0
global_step = -1
start_time = time.time()
try:
for epoch in range(n_epochs):
# Iterate over the books in the training corpus
for index, file_path in enumerate(all_files, 1):
book_start_time = time.time()
text_data = read_text_file(file_path) + " <|endoftext|> "
print(f"Tokenizing file {index} of {total_files}: {file_path}")
# Initialize new data loaders for each book
train_loader, val_loader = create_dataloaders(
text_data,
train_ratio=train_ratio,
batch_size=batch_size,
max_length=GPT_CONFIG_124M["ctx_len"],
stride=GPT_CONFIG_124M["ctx_len"]
)
print(f"Training ...")
model.train()
for input_batch, target_batch in train_loader:
optimizer.zero_grad()
loss = calc_loss_batch(input_batch, target_batch, model, device)
loss.backward()
optimizer.step()
tokens_seen += input_batch.numel()
global_step += 1
# Optional evaluation step
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter)
train_losses.append(train_loss)
val_losses.append(val_loss)
track_tokens_seen.append(tokens_seen)
print(f"Ep {epoch+1} (Step {global_step}): "
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
# Generate text passage
if global_step % print_sample_iter == 0:
generate_and_print_sample(
model, train_loader.dataset.tokenizer, device, start_context
)
if global_step % save_ckpt_freq == 0:
file_name = output_dir / f"model_pg_{global_step}.pth"
torch.save(model.state_dict(), file_name)
print(f"Saved {file_name}")
print_eta(start_time, book_start_time, index, total_files)
except KeyboardInterrupt:
file_name = output_dir / f"model_pg_{global_step}_interrupted.pth"
torch.save(model.state_dict(), file_name)
print(f"Saved {file_name}")
return train_losses, val_losses, track_tokens_seen
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='GPT Model Training Configuration')
parser.add_argument('--data_dir', type=str, default='gutenberg/data',
help='Directory containing the training data')
parser.add_argument('--output_dir', type=str, default='model_checkpoints',
help='Directory where the model checkpoints will be saved')
parser.add_argument('--n_epochs', type=int, default=1,
help='Number of epochs to train the model')
parser.add_argument('--print_sample_iter', type=int, default=500,
help='Iterations between printing sample outputs')
parser.add_argument('--eval_freq', type=int, default=100,
help='Frequency of evaluations during training')
parser.add_argument('--save_ckpt_freq', type=int, default=100_000,
help='Frequency of saving model checkpoints during training')
parser.add_argument('--lr', type=float, default=5e-4,
help='Learning rate for the optimizer')
parser.add_argument('--batch_size', type=int, default=4,
help='Batch size for training')
args = parser.parse_args()
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"ctx_len": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-key-value bias
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1)
data_dir = args.data_dir
all_files = [os.path.join(path, name) for path, subdirs, files
in os.walk(data_dir) for name in files if name.endswith((".txt"))]
total_files = len(all_files)
if total_files == 0:
print("No training text files found. Make sure you "
"selected the correct input directory")
quit()
print("Total files:", total_files)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
train_losses, val_losses, tokens_seen = train_model_simple(
model, optimizer, device,
batch_size=args.batch_size,
n_epochs=args.n_epochs,
eval_freq=args.eval_freq,
eval_iter=1,
print_sample_iter=args.print_sample_iter,
output_dir=output_dir,
save_ckpt_freq=args.save_ckpt_freq,
start_context="Every effort moves you",
)
epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses))
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
torch.save(model.state_dict(), output_dir / "model_pg_final.pth")
print(f"Maximum GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")


@@ -0,0 +1,313 @@
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.
# This file can be run as a standalone script.
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
#####################################
# Chapter 2
#####################################
class GPTDatasetV1(Dataset):
def __init__(self, txt, tokenizer, max_length, stride):
self.tokenizer = tokenizer
self.input_ids = []
self.target_ids = []
token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})
for i in range(0, len(token_ids) - max_length, stride):
input_chunk = token_ids[i:i + max_length]
target_chunk = token_ids[i + 1: i + max_length + 1]
self.input_ids.append(torch.tensor(input_chunk))
self.target_ids.append(torch.tensor(target_chunk))
def __len__(self):
return len(self.input_ids)
def __getitem__(self, idx):
return self.input_ids[idx], self.target_ids[idx]
def create_dataloader_v1(txt, batch_size=4, max_length=256,
stride=128, shuffle=True, drop_last=True):
tokenizer = tiktoken.get_encoding("gpt2")
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
return dataloader
#####################################
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
#####################################
# Chapter 4
#####################################
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
return self.scale * norm_x + self.shift
class GELU(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
nn.Dropout(cfg["drop_rate"])
)
def forward(self, x):
return self.layers(x)
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
block_size=cfg["ctx_len"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_resid = nn.Dropout(cfg["drop_rate"])
def forward(self, x):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
x = self.drop_resid(x)
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_resid(x)
x = x + shortcut # Add the original input back
return x
class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
def forward(self, in_idx):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits
def generate_text_simple(model, idx, max_new_tokens, context_size):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# Crop current context if it exceeds the supported context size
# E.g., if LLM supports only 5 tokens, and the context size is 10
# then only the last 5 tokens are used as context
idx_cond = idx[:, -context_size:]
# Get the predictions
with torch.no_grad():
logits = model(idx_cond)
# Focus only on the last time step
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
logits = logits[:, -1, :]
# Get the idx of the vocab entry with the highest logits value
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
return idx
#####################################
# Chapter 5
####################################
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)
logits = logits.view(-1, logits.size(-1))
loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
return loss
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss, batches_seen = 0., 0.
if num_batches is None:
num_batches = len(data_loader)
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device)
total_loss += loss.item()
batches_seen += 1
else:
break
return total_loss / batches_seen
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
model.train()
return train_loss, val_loss
def generate_and_print_sample(model, tokenizer, device, start_context):
model.eval()
context_size = model.pos_emb.weight.shape[0]
encoded = text_to_token_ids(start_context, tokenizer).to(device)
with torch.no_grad():
token_ids = generate_text_simple(model=model, idx=encoded,
max_new_tokens=50, context_size=context_size)
decoded_text = token_ids_to_text(token_ids, tokenizer)
print(decoded_text.replace("\n", " ")) # Compact print format
model.train()
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
fig, ax1 = plt.subplots()
# Plot training and validation loss against epochs
ax1.plot(epochs_seen, train_losses, label="Training loss")
ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss")
ax1.legend(loc="upper right")
# Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
ax2.plot(tokens_seen, train_losses, alpha=0) # Invisible plot for aligning ticks
ax2.set_xlabel("Tokens seen")
fig.tight_layout() # Adjust layout to make room
plt.savefig(output_dir / "losses.pdf")
def text_to_token_ids(text, tokenizer):
encoded = tokenizer.encode(text)
encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
return encoded_tensor
def token_ids_to_text(token_ids, tokenizer):
flat = token_ids.squeeze(0) # remove batch dimension
return tokenizer.decode(flat.tolist())