Merge pull request #153 from rasbt/ch06-exercises

Chapter 6 wrap-up
This commit is contained in:
Sebastian Raschka 2024-05-13 08:14:08 -05:00 committed by GitHub
commit 968af7e0ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 624 additions and 9 deletions

View File

@ -38,9 +38,10 @@ jobs:
- name: Test Selected Python Scripts
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest setup/02_installing-python-libraries/tests.py
pytest ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
run: |

View File

@ -38,9 +38,10 @@ jobs:
- name: Test Selected Python Scripts
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest setup/02_installing-python-libraries/tests.py
pytest ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
run: |

View File

@ -41,9 +41,10 @@ jobs:
- name: Test Selected Python Scripts
shell: bash
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest setup/02_installing-python-libraries/tests.py
pytest ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
shell: bash

2
.gitignore vendored
View File

@ -6,6 +6,8 @@ appendix-D/01_main-chapter-code/3.pdf
ch05/01_main-chapter-code/loss-plot.pdf
ch05/01_main-chapter-code/temperature-plot.pdf
ch05/01_main-chapter-code/the-verdict.txt
ch06/01_main-chapter-code/loss-plot.pdf
ch06/01_main-chapter-code/accuracy-plot.pdf
# Checkpoint files
ch05/01_main-chapter-code/gpt2/

View File

@ -52,7 +52,7 @@ Alternatively, you can view this and other files on GitHub at [https://github.co
| Ch 3: Coding Attention Mechanisms | - [ch03.ipynb](ch03/01_main-chapter-code/ch03.ipynb)<br/>- [multihead-attention.ipynb](ch03/01_main-chapter-code/multihead-attention.ipynb) (summary) <br/>- [exercise-solutions.ipynb](ch03/01_main-chapter-code/exercise-solutions.ipynb)| [./ch03](./ch03) |
| Ch 4: Implementing a GPT Model from Scratch | - [ch04.ipynb](ch04/01_main-chapter-code/ch04.ipynb)<br/>- [gpt.py](ch04/01_main-chapter-code/gpt.py) (summary)<br/>- [exercise-solutions.ipynb](ch04/01_main-chapter-code/exercise-solutions.ipynb) | [./ch04](./ch04) |
| Ch 5: Pretraining on Unlabeled Data | - [ch05.ipynb](ch05/01_main-chapter-code/ch05.ipynb)<br/>- [gpt_train.py](ch05/01_main-chapter-code/gpt_train.py) (summary) <br/>- [gpt_generate.py](ch05/01_main-chapter-code/gpt_generate.py) (summary) <br/>- [exercise-solutions.ipynb](ch05/01_main-chapter-code/exercise-solutions.ipynb) | [./ch05](./ch05) |
| Ch 6: Finetuning for Text Classification | - [ch06.ipynb](ch06/01_main-chapter-code/ch06.ipynb) | [./ch06](./ch06) |
| Ch 6: Finetuning for Text Classification | - [ch06.ipynb](ch06/01_main-chapter-code/ch06.ipynb) <br/>- [gpt-class-finetune.py](ch06/01_main-chapter-code/gpt-class-finetune.py) <br/>- [exercise-solutions.ipynb](ch06/01_main-chapter-code/exercise-solutions.ipynb) | [./ch06](./ch06) |
| Ch 7: Finetuning with Human Feedback | Q2 2024 | ... |
| Appendix A: Introduction to PyTorch | - [code-part1.ipynb](appendix-A/01_main-chapter-code/code-part1.ipynb)<br/>- [code-part2.ipynb](appendix-A/01_main-chapter-code/code-part2.ipynb)<br/>- [DDP-script.py](appendix-A/01_main-chapter-code/DDP-script.py)<br/>- [exercise-solutions.ipynb](appendix-A/01_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) |
| Appendix B: References and Further Reading | No code | - |

View File

@ -199,8 +199,8 @@
}
],
"source": [
"total_steps = len(train_loader) * n_epochs * train_loader.batch_size\n",
"warmup_steps = int(0.1 * total_steps) # 10% warmup\n",
"total_steps = len(train_loader) * n_epochs\n",
"warmup_steps = int(0.2 * total_steps) # 20% warmup\n",
"print(warmup_steps)"
]
},
@ -779,7 +779,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -1513,7 +1513,7 @@
"id": "669e1fd1-ace8-44b4-b438-185ed0ba8b33",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-3.webp\" width=500px>"
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-3.webp?123\" width=500px>"
]
},
{
@ -1524,6 +1524,14 @@
"- Before explaining the loss calculation, let's have a brief look at how the model outputs are turned into class labels"
]
},
{
"cell_type": "markdown",
"id": "557996dd-4c6b-49c4-ab83-f60ef7e1d69e",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/class-argmax.webp\" width=600px>"
]
},
{
"cell_type": "code",
"execution_count": 26,
@ -2347,7 +2355,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -0,0 +1,168 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ba450fb1-8a26-4894-ab7a-5d7bfefe90ce",
"metadata": {},
"source": [
"<font size=\"1\">\n",
"Supplementary code for \"Build a Large Language Model From Scratch\": <a href=\"https://www.manning.com/books/build-a-large-language-model-from-scratch\">https://www.manning.com/books/build-a-large-language-model-from-scratch</a> by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>"
]
},
{
"cell_type": "markdown",
"id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14",
"metadata": {},
"source": [
"# Chapter 6 Exercise solutions"
]
},
{
"cell_type": "markdown",
"id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
"metadata": {},
"source": [
"## Exercise 6.1: Increasing the context length"
]
},
{
"cell_type": "markdown",
"id": "5860ba9f-2db3-4480-b96b-4be1c68981eb",
"metadata": {},
"source": [
"We can pad the inputs to the maximum number of tokens to the maximum the model supports by setting the max length to\n",
"\n",
"```python\n",
"max_length = 1024\n",
"\n",
"train_dataset = SpamDataset(base_path / \"train.csv\", max_length=max_length, tokenizer=tokenizer)\n",
"val_dataset = SpamDataset(base_path / \"validation.csv\", max_length=max_length, tokenizer=tokenizer)\n",
"test_dataset = SpamDataset(base_path / \"test.csv\", max_length=max_length, tokenizer=tokenizer)\n",
"\n",
"```\n",
"\n",
"or, equivalently, we can define the `max_length` via:\n",
"\n",
"```python\n",
"max_length = model.pos_emb.weight.shape[0]\n",
"```\n",
"\n",
"or\n",
"\n",
"```python\n",
"max_length = BASE_CONFIG[\"context_length\"]\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "2b0f4d5d-17fd-4265-93d8-ea08a22fdaf8",
"metadata": {},
"source": [
"For convenience, you can run this experiment via\n",
"\n",
"```\n",
"python additional-experiments.py --context_length \"model_context_length\"\n",
"```\n",
"\n",
"using the code in the [../02_bonus_additional-experiments](../02_bonus_additional-experiments) folder, which results in a substantially worse test accuracy of 78.33% (versus the 95.67% in the main chapter)."
]
},
{
"cell_type": "markdown",
"id": "5a780455-f52a-48d1-ab82-6afd40bcad8b",
"metadata": {},
"source": [
"## Exercise 6.2: Finetuning the whole model"
]
},
{
"cell_type": "markdown",
"id": "56aa5208-aa29-4165-a0ec-7480754e2a18",
"metadata": {},
"source": [
"Instead of finetuning just the final transformer block, we can finetune the entire model by removing the following lines from the code:\n",
"\n",
"```python\n",
"for param in model.parameters():\n",
" param.requires_grad = False\n",
"```\n",
"\n",
"For convenience, you can run this experiment via\n",
"\n",
"```\n",
"python additional-experiments.py --trainable_layers all\n",
"```\n",
"\n",
"using the code in the [../02_bonus_additional-experiments](../02_bonus_additional-experiments) folder, which results in a 1% improved test accuracy of 96.67% (versus the 95.67% in the main chapter)."
]
},
{
"cell_type": "markdown",
"id": "2269bce3-f2b5-4a76-a692-5977c75a57b6",
"metadata": {},
"source": [
"## Exercise 6.3: Finetuning the first versus last token "
]
},
{
"cell_type": "markdown",
"id": "7418a629-51b6-4aa2-83b7-bc0261bc370f",
"metadata": {},
"source": [
"ther than finetuning the last output token, we can finetune the first output token by changing \n",
"\n",
"```python\n",
"model(input_batch)[:, -1, :]\n",
"```\n",
"\n",
"to\n",
"\n",
"```python\n",
"model(input_batch)[:, 0, :]\n",
"```\n",
"\n",
"everywhere in the code.\n",
"\n",
"For convenience, you can run this experiment via\n",
"\n",
"```\n",
"python additional-experiments.py --trainable_token first\n",
"```\n",
"\n",
"using the code in the [../02_bonus_additional-experiments](../02_bonus_additional-experiments) folder, which results in a substantially worse test accuracy of 75.00% (versus the 95.67% in the main chapter)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5e6188a-f182-4f26-b9e5-ccae3ecadae0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,418 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# This is a summary file containing the main takeaways from chapter 6.
import urllib.request
import zipfile
import os
from pathlib import Path
import time
import matplotlib.pyplot as plt
import pandas as pd
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
if data_file_path.exists():
print(f"{data_file_path} already exists. Skipping download and extraction.")
return
# Downloading the file
with urllib.request.urlopen(url) as response:
with open(zip_path, "wb") as out_file:
out_file.write(response.read())
# Unzipping the file
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extracted_path)
# Add .tsv file extension
original_file_path = Path(extracted_path) / "SMSSpamCollection"
os.rename(original_file_path, data_file_path)
print(f"File downloaded and saved as {data_file_path}")
def create_balanced_dataset(df):
# Count the instances of "spam"
num_spam = df[df["Label"] == "spam"].shape[0]
# Randomly sample "ham" instances to match the number of "spam" instances
ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
# Combine ham "subset" with "spam"
balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
return balanced_df
def random_split(df, train_frac, validation_frac):
# Shuffle the entire DataFrame
df = df.sample(frac=1, random_state=123).reset_index(drop=True)
# Calculate split indices
train_end = int(len(df) * train_frac)
validation_end = train_end + int(len(df) * validation_frac)
# Split the DataFrame
train_df = df[:train_end]
validation_df = df[train_end:validation_end]
test_df = df[validation_end:]
return train_df, validation_df, test_df
class SpamDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
self.data = pd.read_csv(csv_file)
# Pre-tokenize texts
self.encoded_texts = [
tokenizer.encode(text) for text in self.data["Text"]
]
if max_length is None:
self.max_length = self._longest_encoded_length()
else:
self.max_length = max_length
# Truncate sequences if they are longer than max_length
self.encoded_texts = [
encoded_text[:self.max_length]
for encoded_text in self.encoded_texts
]
# Pad sequences to the longest sequence
self.encoded_texts = [
encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
for encoded_text in self.encoded_texts
]
def __getitem__(self, index):
encoded = self.encoded_texts[index]
label = self.data.iloc[index]["Label"]
return (
torch.tensor(encoded, dtype=torch.long),
torch.tensor(label, dtype=torch.long)
)
def __len__(self):
return len(self.data)
def _longest_encoded_length(self):
max_length = 0
for encoded_text in self.encoded_texts:
encoded_length = len(encoded_text)
if encoded_length > max_length:
max_length = encoded_length
return max_length
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device)
total_loss += loss.item()
else:
break
return total_loss / num_batches
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
model.train()
return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, tokenizer):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
# Main training loop
for epoch in range(num_epochs):
model.train() # Set model to training mode
for input_batch, target_batch in train_loader:
optimizer.zero_grad() # Reset loss gradients from previous epoch
loss = calc_loss_batch(input_batch, target_batch, model, device)
loss.backward() # Calculate loss gradients
optimizer.step() # Update model weights using loss gradients
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
global_step += 1
# Optional evaluation step
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): "
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
# Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
val_accs.append(val_accuracy)
return train_losses, val_losses, train_accs, val_accs, examples_seen
def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
fig, ax1 = plt.subplots(figsize=(5, 3))
# Plot training and validation loss against epochs
ax1.plot(epochs_seen, train_values, label=f"Training {label}")
ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
ax1.set_xlabel("Epochs")
ax1.set_ylabel(label.capitalize())
ax1.legend()
# Create a second x-axis for tokens seen
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
ax2.plot(examples_seen, train_values, alpha=0) # Invisible plot for aligning ticks
ax2.set_xlabel("Examples seen")
fig.tight_layout() # Adjust layout to make room
plt.savefig(f"{label}-plot.pdf")
# plt.show()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Finetune a GPT model for classification"
)
parser.add_argument(
"--test_mode",
action="store_true",
help=("This flag runs the model in test mode for internal testing purposes. "
"Otherwise, it runs the model as it is used in the chapter (recommended).")
)
args = parser.parse_args()
########################################
# Download and prepare dataset
########################################
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extracted_path = "sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"
download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
balanced_df = create_balanced_dataset(df)
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)
########################################
# Create data loaders
########################################
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = SpamDataset(
csv_file="train.csv",
max_length=None,
tokenizer=tokenizer
)
val_dataset = SpamDataset(
csv_file="validation.csv",
max_length=train_dataset.max_length,
tokenizer=tokenizer
)
test_dataset = SpamDataset(
csv_file="test.csv",
max_length=train_dataset.max_length,
tokenizer=tokenizer
)
num_workers = 0
batch_size = 8
torch.manual_seed(123)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
########################################
# Load pretrained model
########################################
# Small GPT model for testing purposes
if args.test_mode:
BASE_CONFIG = {
"vocab_size": 50257,
"context_length": 120,
"drop_rate": 0.0,
"qkv_bias": False,
"emb_dim": 12,
"n_layers": 1,
"n_heads": 2
}
model = GPTModel(BASE_CONFIG)
model.eval()
device = "cpu"
model.to(device)
# Code as it is used in the main chapter
else:
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"
BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"drop_rate": 0.0, # Dropout rate
"qkv_bias": True # Query-key-value bias
}
model_configs = {
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
########################################
# Modify and pretrained model
########################################
for param in model.parameters():
param.requires_grad = False
torch.manual_seed(123)
num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
for param in model.final_norm.parameters():
param.requires_grad = True
########################################
# Finetune modified model
########################################
start_time = time.time()
torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=num_epochs, eval_freq=50, eval_iter=5,
tokenizer=tokenizer
)
end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
########################################
# Plot results
########################################
# loss plot
epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)
# accuracy plot
epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))
plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")

View File

@ -0,0 +1,16 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# File for internal use (unit tests)
import subprocess
def test_gpt_class_finetune():
command = ["python", "ch06/01_main-chapter-code/gpt-class-finetune.py", "--test_mode"]
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode == 0, f"Script exited with errors: {result.stderr}"