{
"cells": [
{
"cell_type": "markdown",
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b",
"metadata": {
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b"
},
"source": [
"\n",
"Supplementary code for \"Build a Large Language Model From Scratch\": https://www.manning.com/books/build-a-large-language-model-from-scratch by Sebastian Raschka\n",
"Code repository: https://github.com/rasbt/LLMs-from-scratch\n",
""
]
},
{
"cell_type": "markdown",
"id": "bfabadb8-5935-45ff-b39c-db7a29012129",
"metadata": {
"id": "bfabadb8-5935-45ff-b39c-db7a29012129"
},
"source": [
"# Chapter 6: Finetuning for Text Classification"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"outputId": "9495f150-9d79-4910-d6e7-6c0d9aae4a41"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"matplotlib version: 3.8.2\n",
"numpy version: 1.26.0\n",
"tiktoken version: 0.5.1\n",
"torch version: 2.2.2\n",
"tensorflow version: 2.15.0\n",
"pandas version: 2.2.1\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\"matplotlib\",\n",
" \"numpy\",\n",
" \"tiktoken\",\n",
" \"torch\",\n",
" \"tensorflow\", # For OpenAI's pretrained weights\n",
" \"pandas\" # Dataset loading\n",
" ]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "3a84cf35-b37f-4c15-8972-dfafc9fadc1c",
"metadata": {
"id": "3a84cf35-b37f-4c15-8972-dfafc9fadc1c"
},
"source": [
"## 6.1 Different categories of finetuning"
]
},
{
"cell_type": "markdown",
"id": "ede3d731-5123-4f02-accd-c670ce50a5a3",
"metadata": {
"id": "ede3d731-5123-4f02-accd-c670ce50a5a3"
},
"source": [
"- No code in this section"
]
},
{
"cell_type": "markdown",
"id": "ac45579d-d485-47dc-829e-43be7f4db57b",
"metadata": {},
"source": [
"- The most common ways to finetune language models are instruction finetuning and classification finetuning\n",
"- Instruction finetuning, depicted below, is the topic of the next chapter"
]
},
{
"cell_type": "markdown",
"id": "6c29ef42-46d9-43d4-8bb4-94974e1665e4",
"metadata": {},
"source": [
"(Figure: instruction finetuning; image not available)"
]
},
{
"cell_type": "markdown",
"id": "a7f60321-95b8-46a9-97bf-1d07fda2c3dd",
"metadata": {},
"source": [
"- Classification finetuning, the topic of this chapter, is a procedure you may already be familiar with if you have a background in machine learning -- it's similar to training a convolutional network to classify handwritten digits, for example\n",
"- In classification finetuning, we have a specific number of class labels (for example, \"spam\" and \"not spam\") that the model can output\n",
"- A classification-finetuned model can only predict the classes it has seen during training (for example, \"spam\" or \"not spam\"), whereas an instruction-finetuned model can usually perform many tasks\n",
"- We can think of a classification-finetuned model as a very specialized model; in practice, it is much easier to create a specialized model than a generalist model that performs well on many different tasks"
]
},
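{
"cell_type": "markdown",
"metadata": {},
"source": [
"- A minimal sketch of what \"a specific number of class labels\" means in code (the mapping below is illustrative; the encoding actually used later in this chapter may differ):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: a classifier maps each class label to an integer index,\n",
"# and the model's output layer has one unit per class\n",
"label_map = {\"ham\": 0, \"spam\": 1}\n",
"num_classes = len(label_map)\n",
"print(f\"Number of classes: {num_classes}\")"
]
},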
{
"cell_type": "markdown",
"id": "0b37a0c4-0bb1-4061-b1fe-eaa4416d52c3",
"metadata": {},
"source": [
"(Figure: classification finetuning; image not available)"
]
},
{
"cell_type": "markdown",
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf",
"metadata": {
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf"
},
"source": [
"## 6.2 Preparing the dataset"
]
},
{
"cell_type": "markdown",
"id": "9fbd459f-63fa-4d8c-8499-e23103156c7d",
"metadata": {
"id": "9fbd459f-63fa-4d8c-8499-e23103156c7d"
},
"source": [
"- This section prepares the dataset we use for classification finetuning\n",
"- We use a dataset consisting of spam and non-spam text messages to finetune the LLM to classify them\n",
"- First, we download and unzip the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"outputId": "424e4423-f623-443c-ab9e-656f9e867559"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sms_spam_collection/SMSSpamCollection.tsv already exists. Skipping download and extraction.\n"
]
}
],
"source": [
"import urllib.request\n",
"import zipfile\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
"zip_path = \"sms_spam_collection.zip\"\n",
"extracted_path = \"sms_spam_collection\"\n",
"data_file_path = Path(extracted_path) / \"SMSSpamCollection.tsv\"\n",
"\n",
"def download_and_unzip(url, zip_path, extracted_path, data_file_path):\n",
" if data_file_path.exists():\n",
" print(f\"{data_file_path} already exists. Skipping download and extraction.\")\n",
" return\n",
"\n",
" # Downloading the file\n",
" with urllib.request.urlopen(url) as response:\n",
" with open(zip_path, \"wb\") as out_file:\n",
" out_file.write(response.read())\n",
"\n",
" # Unzipping the file\n",
" with zipfile.ZipFile(zip_path, \"r\") as zip_ref:\n",
" zip_ref.extractall(extracted_path)\n",
"\n",
" # Add .tsv file extension\n",
" original_file_path = Path(extracted_path) / \"SMSSpamCollection\"\n",
" os.rename(original_file_path, data_file_path)\n",
" print(f\"File downloaded and saved as {data_file_path}\")\n",
"\n",
"download_and_unzip(url, zip_path, extracted_path, data_file_path)"
]
},
{
"cell_type": "markdown",
"id": "6aac2d19-06d0-4005-916b-0bd4b1ee50d1",
"metadata": {
"id": "6aac2d19-06d0-4005-916b-0bd4b1ee50d1"
},
"source": [
"- The dataset is saved as a tab-separated text file, which we can load into a pandas DataFrame"
]
},
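{
"cell_type": "markdown",
"metadata": {},
"source": [
"- A minimal sketch of that loading step (assuming the `data_file_path` defined earlier; the file has no header row, so we supply the column names ourselves):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# The file is tab-separated and has no header row\n",
"df = pd.read_csv(data_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
"print(df.shape)"
]
},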
{
"cell_type": "code",
"execution_count": 3,
"id": "da0ed4da-ac31-4e4d-8bdd-2153be4656a4",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "da0ed4da-ac31-4e4d-8bdd-2153be4656a4",
"outputId": "a16c5cde-d341-4887-a93f-baa9bec542ab"
},
"outputs": [
{
"data": {
"text/html": [
"
\n", " | Label | \n", "Text | \n", "
---|---|---|
0 | \n", "ham | \n", "Go until jurong point, crazy.. Available only ... | \n", "
1 | \n", "ham | \n", "Ok lar... Joking wif u oni... | \n", "
2 | \n", "spam | \n", "Free entry in 2 a wkly comp to win FA Cup fina... | \n", "
3 | \n", "ham | \n", "U dun say so early hor... U c already then say... | \n", "
4 | \n", "ham | \n", "Nah I don't think he goes to usf, he lives aro... | \n", "
... | \n", "... | \n", "... | \n", "
5567 | \n", "spam | \n", "This is the 2nd time we have tried 2 contact u... | \n", "
5568 | \n", "ham | \n", "Will ü b going to esplanade fr home? | \n", "
5569 | \n", "ham | \n", "Pity, * was in mood for that. So...any other s... | \n", "
5570 | \n", "ham | \n", "The guy did some bitching but I acted like i'd... | \n", "
5571 | \n", "ham | \n", "Rofl. Its true to its name | \n", "
5572 rows × 2 columns
\n", "