# Mirror of https://github.com/rasbt/LLMs-from-scratch.git
# (synced 2025-07-03 07:04:25 +00:00; 545 lines, 20 KiB, Python)
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-6.
# This file can be run as a standalone script.
|
|
|
|
import os
|
|
from pathlib import Path
|
|
import urllib
|
|
import zipfile
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
import tiktoken
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
|
#####################################
# Chapter 2
#####################################
|
|
|
|
|
|
class GPTDatasetV1(Dataset):
    """Next-token-prediction dataset built from one long text.

    The text is tokenized once; overlapping windows of ``max_length`` tokens
    are then taken every ``stride`` tokens. Each sample is a pair
    ``(input_ids, target_ids)`` where the targets are the inputs shifted one
    token to the right. Texts with ``max_length`` tokens or fewer yield an
    empty dataset.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        # The special end-of-text marker is allowed to appear verbatim in txt
        ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Stop early enough that every window still has a full shifted target
        window_starts = range(0, len(ids) - max_length, stride)
        self.input_ids = [
            torch.tensor(ids[start:start + max_length]) for start in window_starts
        ]
        self.target_ids = [
            torch.tensor(ids[start + 1:start + 1 + max_length]) for start in window_starts
        ]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]
|
|
|
|
|
|
def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True):
    """Wrap ``txt`` in a GPTDatasetV1 (GPT-2 BPE tokens) and return a DataLoader.

    Args:
        txt: Raw training text.
        batch_size: Samples per batch.
        max_length: Tokens per input window.
        stride: Step between consecutive window starts.
        shuffle: Whether the DataLoader shuffles samples each epoch.
        drop_last: Whether an incomplete final batch is discarded.
    """
    gpt2_tokenizer = tiktoken.get_encoding("gpt2")
    windows = GPTDatasetV1(txt, gpt2_tokenizer, max_length, stride)
    return DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )
|
|
|
|
|
|
#####################################
# Chapter 3
#####################################
|
|
class MultiHeadAttention(nn.Module):
    """Causal multi-head scaled dot-product self-attention.

    ``d_out`` is split evenly across ``num_heads`` heads of size
    ``head_dim = d_out // num_heads``. A float upper-triangular buffer
    ``mask`` of shape (context_length, context_length) marks future positions,
    which are set to -inf before the softmax.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Creation order is kept stable so seeded initialization is reproducible
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated heads
        self.dropout = nn.Dropout(dropout)

        # 1 above the diagonal = "may not attend to this (future) position"
        causal = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer('mask', causal)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        def to_heads(t):
            # (batch, seq_len, d_out) -> (batch, num_heads, seq_len, head_dim)
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        keys = to_heads(self.W_key(x))
        queries = to_heads(self.W_query(x))
        values = to_heads(self.W_value(x))

        # Per-head similarity: (batch, num_heads, seq_len, seq_len)
        scores = queries @ keys.transpose(2, 3)
        future = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(future, -torch.inf)

        weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # Back to (batch, seq_len, num_heads, head_dim), then merge the heads
        context = (weights @ values).transpose(1, 2).reshape(batch, seq_len, self.d_out)
        return self.out_proj(context)
|
|
|
|
|
|
#####################################
# Chapter 4
#####################################
|
|
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale/shift.

    Uses the biased (population) variance and a fixed epsilon of 1e-5.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / torch.sqrt(variance + self.eps)
        return normalized * self.scale + self.shift
|
|
|
|
|
|
class GELU(nn.Module):
    """GELU activation using the tanh approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        coeff = torch.sqrt(torch.tensor(2.0 / torch.pi))
        inner = coeff * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(emb -> 4*emb), GELU, Linear(4*emb -> emb).

    The two linear layers live at ``self.layers[0]`` and ``self.layers[2]``;
    weight loading elsewhere in the file indexes them that way.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = 4 * emb_dim
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),  # expand
            GELU(),
            nn.Linear(hidden_dim, emb_dim),  # project back
        )

    def forward(self, x):
        return self.layers(x)
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP sub-layers.

    Each sub-layer normalizes its input, applies dropout to its output, and
    adds the result back onto the sub-layer's input (residual connection).
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        # Submodule creation order is kept stable for seeded initialization
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Attention sub-layer with residual; shape stays [batch, tokens, emb]
        x = x + self.drop_resid(self.att(self.norm1(x)))
        # Feed-forward sub-layer with residual
        x = x + self.drop_resid(self.ff(self.norm2(x)))
        return x
|
|
|
|
|
|
class GPTModel(nn.Module):
    """GPT-style decoder: token + learned positional embeddings, a stack of
    ``cfg["n_layers"]`` TransformerBlocks, a final LayerNorm and a bias-free
    projection to ``cfg["vocab_size"]`` logits.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        # Creation order is kept stable so seeded initialization is reproducible
        self.tok_emb = nn.Embedding(cfg["vocab_size"], emb_dim)
        self.pos_emb = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)

        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        # in_idx: (batch, seq_len) token ids; returns (batch, seq_len, vocab) logits
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)
        hidden = self.tok_emb(in_idx) + self.pos_emb(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        return self.out_head(self.final_norm(hidden))
|
|
|
|
|
|
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedily extend ``idx`` by ``max_new_tokens`` tokens.

    Args:
        model: Callable mapping (batch, n_tokens) ids to
            (batch, n_tokens, vocab) logits.
        idx: (batch, n_tokens) tensor of token ids to continue.
        max_new_tokens: Number of tokens to append.
        context_size: Only the last ``context_size`` tokens are fed to the
            model at each step.

    Returns:
        A (batch, n_tokens + max_new_tokens) tensor of token ids.
    """
    for _ in range(max_new_tokens):
        window = idx[:, -context_size:]
        with torch.no_grad():
            last_logits = model(window)[:, -1, :]
        # Greedy decoding: pick the highest-scoring vocabulary entry
        next_token = last_logits.argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_token], dim=1)
    return idx
|
|
|
|
|
|
#####################################
# Chapter 5
#####################################
|
|
def assign(left, right):
    """Return ``right`` as a new ``nn.Parameter`` after checking it matches
    ``left``'s shape.

    Raises:
        ValueError: if ``left.shape != right.shape``.
    """
    if left.shape == right.shape:
        return torch.nn.Parameter(torch.tensor(right))
    raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
|
|
|
|
|
|
def load_weights_into_gpt(gpt, params):
    """Copy checkpoint arrays from ``params`` into ``gpt``, replacing its parameters.

    Presumably ``params`` follows the OpenAI GPT-2 checkpoint layout (keys
    ``wpe``, ``wte``, ``g``, ``b`` and a ``blocks`` list with ``attn``,
    ``mlp``, ``ln_1``, ``ln_2`` entries) — confirm against the loader.
    Weight matrices are transposed (``.T``) before being assigned to the
    ``nn.Linear`` layers. Every assignment goes through ``assign``, which
    raises ValueError on a shape mismatch. The attention projections must have
    biases (qkv_bias=True); otherwise ``assign`` is handed ``None``.
    The output head receives its own copy of ``wte`` (weight tying by value).
    """
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    blocks = params["blocks"]
    for b in range(len(blocks)):
        src = blocks[b]
        dst = gpt.trf_blocks[b]
        att, ff = dst.att, dst.ff

        # The fused c_attn arrays hold query, key and value side by side on
        # the last axis; split them into three equal parts.
        q_w, k_w, v_w = np.split(src["attn"]["c_attn"]["w"], 3, axis=-1)
        att.W_query.weight = assign(att.W_query.weight, q_w.T)
        att.W_key.weight = assign(att.W_key.weight, k_w.T)
        att.W_value.weight = assign(att.W_value.weight, v_w.T)

        q_b, k_b, v_b = np.split(src["attn"]["c_attn"]["b"], 3, axis=-1)
        att.W_query.bias = assign(att.W_query.bias, q_b)
        att.W_key.bias = assign(att.W_key.bias, k_b)
        att.W_value.bias = assign(att.W_value.bias, v_b)

        att.out_proj.weight = assign(att.out_proj.weight, src["attn"]["c_proj"]["w"].T)
        att.out_proj.bias = assign(att.out_proj.bias, src["attn"]["c_proj"]["b"])

        ff.layers[0].weight = assign(ff.layers[0].weight, src["mlp"]["c_fc"]["w"].T)
        ff.layers[0].bias = assign(ff.layers[0].bias, src["mlp"]["c_fc"]["b"])
        ff.layers[2].weight = assign(ff.layers[2].weight, src["mlp"]["c_proj"]["w"].T)
        ff.layers[2].bias = assign(ff.layers[2].bias, src["mlp"]["c_proj"]["b"])

        dst.norm1.scale = assign(dst.norm1.scale, src["ln_1"]["g"])
        dst.norm1.shift = assign(dst.norm1.shift, src["ln_1"]["b"])
        dst.norm2.scale = assign(dst.norm2.scale, src["ln_2"]["g"])
        dst.norm2.shift = assign(dst.norm2.shift, src["ln_2"]["b"])

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
|
|
|
|
|
|
def text_to_token_ids(text, tokenizer):
    """Encode ``text`` and return the ids as a (1, n_tokens) tensor.

    ``<|endoftext|>`` is permitted to appear literally in ``text``.
    """
    ids = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    batched = torch.tensor(ids)
    return batched.unsqueeze(dim=0)  # prepend a batch dimension of size 1
|
|
|
|
|
|
def token_ids_to_text(token_ids, tokenizer):
    """Decode a (1, n_tokens) id tensor back to text.

    A leading dimension of size 1 is dropped before decoding.
    """
    ids = token_ids.squeeze(dim=0).tolist()
    return tokenizer.decode(ids)
|
|
|
|
|
|
def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Average ``calc_loss_batch`` over the first ``num_batches`` batches.

    Args:
        data_loader: Sized iterable of ``(input_batch, target_batch)`` pairs.
        model: Model passed through to ``calc_loss_batch``.
        device: Device the batches are moved to by ``calc_loss_batch``.
        num_batches: How many leading batches to evaluate; ``None`` means all.
            Values larger than ``len(data_loader)`` are clamped.

    Returns:
        The mean per-batch loss as a float, or NaN when there is nothing to
        average (empty loader, or ``num_batches`` <= 0).
    """
    if len(data_loader) == 0:
        return float("nan")
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        # Never ask for more batches than the loader actually has
        num_batches = min(num_batches, len(data_loader))
    if num_batches <= 0:
        # Previously this divided by zero (or by a negative count); treat it
        # like the empty-loader case instead.
        return float("nan")

    total_loss = 0.
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i >= num_batches:
            break
        loss = calc_loss_batch(input_batch, target_batch, model, device)
        total_loss += loss.item()
    return total_loss / num_batches
|
|
|
|
|
|
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """Return ``(train_loss, val_loss)`` averaged over ``eval_iter`` batches each.

    The model is put in eval mode with gradients disabled for the measurement,
    then switched back to train mode before returning.
    """
    model.eval()
    with torch.no_grad():
        losses = tuple(
            calc_loss_loader(loader, model, device, num_batches=eval_iter)
            for loader in (train_loader, val_loader)
        )
    model.train()
    return losses
|
|
|
|
|
|
#####################################
# Chapter 6
#####################################
|
|
|
|
|
|
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    """Download the SMS spam zip from ``url`` and extract it to ``data_file_path``.

    The archive is saved to ``zip_path`` and extracted into ``extracted_path``.
    The extracted ``SMSSpamCollection`` file is then renamed to
    ``data_file_path``. If ``data_file_path`` already exists, nothing is
    downloaded.

    Args:
        url: Address of the zip archive.
        zip_path: Where the downloaded archive is written.
        extracted_path: Directory the archive is extracted into.
        data_file_path: Final path of the data file; ``str`` or ``Path``.
    """
    # `import urllib` alone does not guarantee the `request` submodule is
    # loaded, so import it explicitly before using urlopen.
    import urllib.request

    data_file_path = Path(data_file_path)
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Downloading the file
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")
|
|
|
|
|
|
def create_balanced_dataset(df):
    """Downsample "ham" rows so the result has as many "ham" as "spam" rows.

    The "ham" subset is drawn with ``random_state=123`` (deterministic), and
    the returned frame lists the sampled ham rows first, then all spam rows.
    """
    spam_rows = df[df["Label"] == "spam"]
    ham_rows = df[df["Label"] == "ham"]
    ham_subset = ham_rows.sample(len(spam_rows), random_state=123)
    return pd.concat([ham_subset, spam_rows])
|
|
|
|
|
|
def random_split(df, train_frac, validation_frac):
    """Shuffle ``df`` (seed 123) and cut it into train/validation/test frames.

    The train and validation sizes are ``int(len(df) * frac)``; the test
    split gets every remaining row. All three share the reset 0..n-1 index
    of the shuffled frame.
    """
    shuffled = df.sample(frac=1, random_state=123).reset_index(drop=True)
    n_rows = len(shuffled)
    train_stop = int(n_rows * train_frac)
    val_stop = train_stop + int(n_rows * validation_frac)
    return (
        shuffled.iloc[:train_stop],
        shuffled.iloc[train_stop:val_stop],
        shuffled.iloc[val_stop:],
    )
|
|
|
|
|
|
class SpamDataset(Dataset):
    """Dataset of fixed-length token-id sequences with integer labels from a CSV.

    The CSV is read with ``pd.read_csv`` and must provide ``Text`` and
    ``Label`` columns. Every text is tokenized up front. Sequences are
    truncated to ``max_length`` when one is given (otherwise the longest
    encoded text sets the length). All sequences are then right-padded with
    ``pad_token_id`` to exactly ``self.max_length`` tokens.
    """

    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        self.encoded_texts = [tokenizer.encode(text) for text in self.data["Text"]]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            self.encoded_texts = [ids[:self.max_length] for ids in self.encoded_texts]

        target_len = self.max_length
        self.encoded_texts = [
            ids + [pad_token_id] * (target_len - len(ids))
            for ids in self.encoded_texts
        ]

    def __getitem__(self, index):
        token_ids = torch.tensor(self.encoded_texts[index], dtype=torch.long)
        target = torch.tensor(self.data.iloc[index]["Label"], dtype=torch.long)
        return token_ids, target

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        # Length of the longest encoded text; 0 when there are no texts
        return max((len(ids) for ids in self.encoded_texts), default=0)
|
|
|
|
|
|
@torch.no_grad()  # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    """Fraction of examples whose last-token argmax matches the target label.

    Args:
        data_loader: Sized iterable of ``(input_batch, target_batch)`` pairs.
        model: Returns (batch, n_tokens, n_classes) logits; only the last
            token's logits are used. It is switched to eval mode and left
            there.
        device: Device each batch is moved to.
        num_batches: How many leading batches to evaluate; ``None`` means all.
            Values larger than ``len(data_loader)`` are clamped.

    Returns:
        Accuracy in [0, 1], or NaN when no examples were evaluated (empty
        loader or ``num_batches`` <= 0), matching ``calc_loss_loader``.
    """
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i >= num_batches:
            break
        input_batch, target_batch = input_batch.to(device), target_batch.to(device)
        logits = model(input_batch)[:, -1, :]  # Logits of last output token
        predicted_labels = torch.argmax(logits, dim=-1)

        num_examples += predicted_labels.shape[0]
        correct_predictions += (predicted_labels == target_batch).sum().item()

    if num_examples == 0:
        # Previously a ZeroDivisionError; report "undefined" instead
        return float("nan")
    return correct_predictions / num_examples
|
|
|
|
|
|
def calc_loss_batch(input_batch, target_batch, model, device):
    """Cross-entropy between the last token's logits and the class targets.

    Both batches are moved to ``device`` first; the model's output is indexed
    as (batch, n_tokens, n_classes) and only position -1 is scored.
    """
    inputs = input_batch.to(device)
    targets = target_batch.to(device)
    last_token_logits = model(inputs)[:, -1, :]
    return torch.nn.functional.cross_entropy(last_token_logits, targets)
|
|
|
|
|
|
# Overall the same as `train_model_simple` in chapter 5
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter):
    """Fine-tune a last-token classifier and record losses and accuracies.

    Every ``eval_freq`` optimizer steps (counting from step 0), train and
    validation losses are measured over ``eval_iter`` batches and printed.
    After each epoch, train and validation accuracies over ``eval_iter``
    batches are measured and printed.

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs, examples_seen)``,
        where ``examples_seen`` is the total number of training examples
        processed.
    """
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    examples_seen = 0
    global_step = -1  # first increment makes the first step number 0

    for epoch in range(num_epochs):
        model.train()

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            examples_seen += input_batch.shape[0]  # count examples, not tokens
            global_step += 1

            if global_step % eval_freq != 0:
                continue
            train_loss, val_loss = evaluate_model(
                model, train_loader, val_loader, device, eval_iter)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            print(f"Ep {epoch+1} (Step {global_step:06d}): "
                  f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # End-of-epoch accuracy snapshot
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen
|
|
|
|
|
|
def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
    """Plot training/validation curves against epochs, with a secondary
    "Examples seen" x-axis, save the figure as ``{label}-plot.pdf`` and show it.
    """
    fig, ax_epochs = plt.subplots(figsize=(5, 3))

    # Primary axis: both curves against epochs
    ax_epochs.plot(epochs_seen, train_values, label=f"Training {label}")
    ax_epochs.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
    ax_epochs.set_xlabel("Epochs")
    ax_epochs.set_ylabel(label.capitalize())
    ax_epochs.legend()

    # Twin x-axis sharing the y-axis; the invisible plot only sets its ticks
    ax_examples = ax_epochs.twiny()
    ax_examples.plot(examples_seen, train_values, alpha=0)
    ax_examples.set_xlabel("Examples seen")

    fig.tight_layout()
    plt.savefig(f"{label}-plot.pdf")
    plt.show()
|