{
"cells": [
{
"cell_type": "markdown",
"id": "12e91914-5f51-43fa-b65b-625e73b4d17b",
"metadata": {
"id": "12e91914-5f51-43fa-b65b-625e73b4d17b"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "c2520ec3-722f-4f44-bdd1-885b13e7afbf",
"metadata": {
"id": "c2520ec3-722f-4f44-bdd1-885b13e7afbf"
},
"source": [
"# Chapter 7: Finetuning To Follow Instructions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4e19327b-6c02-4881-ad02-9b6d3ec0b1b4",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4e19327b-6c02-4881-ad02-9b6d3ec0b1b4",
"outputId": "6560a9ce-8cbe-4c37-885b-e9c8c1946f69"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"matplotlib version: 3.7.1\n",
"tiktoken version: 0.7.0\n",
"torch version: 2.3.0+cu121\n",
"tqdm version: 4.66.4\n",
"tensorflow version: 2.15.0\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\n",
" \"matplotlib\", # Plotting library\n",
" \"tiktoken\", # Tokenizer\n",
" \"torch\", # Deep learning library\n",
" \"tqdm\", # Progress bar\n",
" \"tensorflow\", # For OpenAI's pretrained weights\n",
"]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "264fca98-2f9a-4193-b435-2abfa3b4142f",
"metadata": {
"id": "264fca98-2f9a-4193-b435-2abfa3b4142f"
},
"source": [
"[figure]"
]
},
{
"cell_type": "markdown",
"id": "8bbc68e9-75b3-41f1-ac2c-e071c3cd0813",
"metadata": {
"id": "8bbc68e9-75b3-41f1-ac2c-e071c3cd0813"
},
"source": [
"## 7.1 Introduction to instruction finetuning"
]
},
{
"cell_type": "markdown",
"id": "53dba24a-6805-496c-9a7f-c75e2d3527ab",
"metadata": {
"id": "53dba24a-6805-496c-9a7f-c75e2d3527ab"
},
"source": [
"- In chapter 5, we saw that pretraining an LLM involves a training procedure where it learns to generate one word at a time\n",
"- Hence, a pretrained LLM is good at text completion, but it is not good at following instructions\n",
"- In this chapter, we teach the LLM to better follow instructions"
]
},
{
"cell_type": "markdown",
"id": "18dc0535-0904-44ed-beaf-9b678292ef35",
"metadata": {
"id": "18dc0535-0904-44ed-beaf-9b678292ef35"
},
"source": [
"[figure]"
]
},
{
"cell_type": "markdown",
"id": "b4698b23-12e0-4bd7-a140-ccb3dd71d4e8",
"metadata": {
"id": "b4698b23-12e0-4bd7-a140-ccb3dd71d4e8"
},
"source": [
"- An optional step after instruction finetuning is preference tuning, which refines the response style of an LLM; readers interested in preference tuning can find example code in the bonus materials: [../04_preference-tuning-with-dpo](../04_preference-tuning-with-dpo)\n",
"\n",
"- The topics covered in this chapter are summarized in the figure below\n",
"\n",
"[figure]"
]
},
{
"cell_type": "markdown",
"id": "5384f0cf-ef3c-4436-a5fa-59bd25649f86",
"metadata": {
"id": "5384f0cf-ef3c-4436-a5fa-59bd25649f86"
},
"source": [
"## 7.2 Preparing a dataset for supervised instruction finetuning"
]
},
{
"cell_type": "markdown",
"id": "f8b34ff8-619f-4e89-bd03-ce513269760d",
"metadata": {
"id": "f8b34ff8-619f-4e89-bd03-ce513269760d"
},
"source": [
"- We will work with an instruction dataset I prepared for this chapter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0G3axLw6kY1N",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0G3axLw6kY1N",
"outputId": "c48ade8c-0d31-4efb-8246-6e6c51669dde"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of entries: 1100\n"
]
}
],
"source": [
"import json\n",
"import os\n",
"import urllib.request\n",
"\n",
"\n",
"def download_and_load_file(file_path, url):\n",
"    # Download the file only if it is not already cached locally\n",
"    if not os.path.exists(file_path):\n",
"        with urllib.request.urlopen(url) as response:\n",
"            text_data = response.read().decode('utf-8')\n",
"        with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
"            file.write(text_data)\n",
"\n",
"    # Parse the cached JSON file into a list of dictionaries\n",
"    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
"        data = json.load(file)\n",
"\n",
"    return data\n",
"\n",
"\n",
"file_path = \"instruction-data.json\"\n",
"url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json\"\n",
"\n",
"data = download_and_load_file(file_path, url)\n",
"print(\"Number of entries:\", len(data))"
]
},
{
"cell_type": "markdown",
"id": "d7af8176-4255-4e92-8c7d-998771733eb8",
"metadata": {
"id": "d7af8176-4255-4e92-8c7d-998771733eb8"
},
"source": [
"- Each item in the `data` list we loaded from the JSON file above is a dictionary in the following form:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "-LiuBMsHkzQV",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-LiuBMsHkzQV",
"outputId": "88fe5be1-da18-45b5-dbb5-abcbcc4558e5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example entry:\n",
"\n",
" {'instruction': 'Identify the correct spelling of the following word.', 'input': 'Ocassion', 'output': \"The correct spelling is 'Occasion.'\"}\n"
]
}
],
"source": [
"print(\"Example entry:\\n\\n\", data[50])"
]
},
{
"cell_type": "markdown",
"id": "c5a32b34-485a-4816-a77a-da14f9fe6e46",
"metadata": {
"id": "c5a32b34-485a-4816-a77a-da14f9fe6e46"
},
"source": [
"- Note that the `'input'` field can be empty:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "uFInFxDDk2Je",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uFInFxDDk2Je",
"outputId": "a07ca278-0205-4ac4-b81e-54a513ece585"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Another example entry:\n",
"\n",
" {'instruction': \"What is an antonym of 'complicated'?\", 'input': '', 'output': \"An antonym of 'complicated' is 'simple'.\"}\n"
]
}
],
"source": [
"print(\"Another example entry:\\n\\n\", data[999])"
]
},
{
"cell_type": "markdown",
"id": "f034799a-6575-45fd-98c9-9d1012d0fd58",
"metadata": {
"id": "f034799a-6575-45fd-98c9-9d1012d0fd58"
},
"source": [
"- Instruction finetuning is often referred to as \"supervised instruction finetuning\" because it involves training a model on a dataset where the input-output pairs are explicitly provided\n",
"- There are different ways to format the entries as inputs to the LLM; the figure below illustrates two example formats that were used for training the Alpaca (https://crfm.stanford.edu/2023/03/13/alpaca.html) and Phi-3 (https://arxiv.org/abs/2404.14219) LLMs, respectively"
]
},
{
"cell_type": "markdown",
"id": "dffa4f70-44d4-4be4-89a9-2159f4885b10",
"metadata": {
"id": "dffa4f70-44d4-4be4-89a9-2159f4885b10"
},
"source": [
"[figure]"
]
},
{
"cell_type": "markdown",
"id": "dd79a74e-befb-491c-be49-f777a6a5b6a6",
"metadata": {
"id": "dd79a74e-befb-491c-be49-f777a6a5b6a6"
},
"source": [
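"- In this chapter, we use Alpaca-style prompt formatting, one of the earliest and most widely used prompt templates for instruction finetuning\n",
"- Below, we implement a `format_input` function that converts an entry into the instruction portion of the prompt (plus the input portion, if present) that we pass to the LLM"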
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "Jhk37nnJnkBh",
"metadata": {
"id": "Jhk37nnJnkBh"
},
"outputs": [],
"source": [
"def format_input(entry):\n",
"    instruction_text = (\n",
"        \"Below is an instruction that describes a task. \"\n",
"        \"Write a response that appropriately completes the request.\"\n",
"        f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
"    )\n",
"\n",
"    input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
"\n",
"    return instruction_text + input_text"
]
},
{
"cell_type": "markdown",
"id": "011e78b4-e89a-4653-a2ee-7b2739ca04d6",
"metadata": {
"id": "011e78b4-e89a-4653-a2ee-7b2739ca04d6"
},
"source": [
"- Below is a formatted example with a non-empty `'input'` field, followed by the desired response"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "F9UQRfjzo4Js",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "F9UQRfjzo4Js",
"outputId": "f05669d2-13a8-4eb3-f549-dab83cec1e00"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Identify the correct spelling of the following word.\n",
"\n",
"### Input:\n",
"Ocassion\n",
"\n",
"### Response:\n",
"The correct spelling is 'Occasion.'\n"
]
}
],
"source": [
"model_input = format_input(data[50])\n",
"desired_response = f\"\\n\\n### Response:\\n{data[50]['output']}\"\n",
"\n",
"print(model_input + desired_response)"
]
},
{
"cell_type": "markdown",
"id": "4dc93ddf-431c-49c0-96f2-fb3a79c4d94c",
"metadata": {
"id": "4dc93ddf-431c-49c0-96f2-fb3a79c4d94c"
},
"source": [
"- Below is a formatted example with an empty `'input'` field; in this case, the `### Input:` section is omitted"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a3891fa9-f738-41cd-946c-80ef9a99c346",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "a3891fa9-f738-41cd-946c-80ef9a99c346",
"outputId": "b9550b1f-8b35-4b00-96d3-a1ce2b76daee"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"What is an antonym of 'complicated'?\n",
"\n",
"### Response:\n",
"An antonym of 'complicated' is 'simple'.\n"
]
}
],
"source": [
"model_input = format_input(data[999])\n",
"desired_response = f\"\\n\\n### Response:\\n{data[999]['output']}\"\n",
"\n",
"print(model_input + desired_response)"
]
},
{
"cell_type": "markdown",
"id": "4aa8afd5-2a21-49a5-90c3-6a03865a4771",
"metadata": {
"id": "4aa8afd5-2a21-49a5-90c3-6a03865a4771"
},
"source": [
"- Lastly, before we prepare the PyTorch data loaders in the next section, we divide the dataset into training, validation, and test sets"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "aFZVopbIlNfx",
"metadata": {
"id": "aFZVopbIlNfx"
},
"outputs": [],
"source": [
"train_portion = int(len(data) * 0.85) # 85% for training\n",
"test_portion = int(len(data) * 0.1) # 10% for testing\n",
"val_portion = len(data) - train_portion - test_portion # Remaining 5% for validation\n",
"\n",
"train_data = data[:train_portion]\n",
"test_data = data[train_portion:train_portion + test_portion]\n",
"val_data = data[train_portion + test_portion:]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "-zf6oht6bIUQ",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-zf6oht6bIUQ",
"outputId": "5a11a57f-2ce2-408f-e05a-a09cb661e49b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training set length: 935\n",
"Validation set length: 55\n",
"Test set length: 110\n"
]
}
],
"source": [
"print(\"Training set length:\", len(train_data))\n",
"print(\"Validation set length:\", len(val_data))\n",
"print(\"Test set length:\", len(test_data))"
]
},
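{
"cell_type": "markdown",
"id": "split-sanity-check-md",
"metadata": {},
"source": [
"- A minimal sanity check, assuming the cells above have run: since the split uses contiguous slices without shuffling, concatenating the subsets in slicing order (train, test, validation) should reproduce the original list exactly, so no entry is lost or used twice. Note that a sequential split like this assumes the JSON file is not ordered in a way that would bias the subsets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "split-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# The slicing order above is train, test, validation; concatenating in that\n",
"# order should give back the full dataset with nothing lost or duplicated\n",
"assert train_data + test_data + val_data == data\n",
"\n",
"for name, subset in [(\"train\", train_data), (\"validation\", val_data), (\"test\", test_data)]:\n",
"    print(f\"{name}: {len(subset) / len(data):.1%} of all entries\")"
]
},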
{
"cell_type": "markdown",
"id": "fcaaf606-f913-4445-8301-632ae10d387d",
"metadata": {
"id": "fcaaf606-f913-4445-8301-632ae10d387d"
},
"source": [
"## 7.3 Creating data loaders for an instruction dataset"
]
},
{
"cell_type": "markdown",
"id": "233f63bd-9755-4d07-8884-5e2e5345cf27",
"metadata": {
"id": "233f63bd-9755-4d07-8884-5e2e5345cf27"
},
"source": [
"[figure]"
]
},
{
"cell_type": "markdown",
"id": "b9af423f-aad9-4b3c-bea5-153021c04862",
"metadata": {
"id": "b9af423f-aad9-4b3c-bea5-153021c04862"
},
"source": [
"- First, we implement an `InstructionDataset` class that pre-tokenizes all inputs in the dataset, similar to the `SpamDataset` in chapter 6\n",
"\n",
"[figure]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "K6MWf0lhu8GP",
"metadata": {
"id": "K6MWf0lhu8GP"
},
"outputs": [],
"source": [
"import tiktoken\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "adc29dc4-f1c7-4c71-937b-95119d6239bb",
"metadata": {
"id": "adc29dc4-f1c7-4c71-937b-95119d6239bb"
},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class InstructionDataset(Dataset):\n",
"    def __init__(self, data, tokenizer):\n",
"        self.data = data\n",
"\n",
"        # Pre-tokenize texts\n",
"        self.encoded_texts = []\n",
"        for entry in data:\n",
"            instruction_plus_input = format_input(entry)\n",
"            response_text = f\"\\n\\n### Response:\\n{entry['output']}\"\n",
"            full_text = instruction_plus_input + response_text\n",
"            self.encoded_texts.append(\n",
"                tokenizer.encode(full_text)\n",
"            )\n",
"\n",
"    def __getitem__(self, index):\n",
"        return self.encoded_texts[index]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)"
]
},
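{
"cell_type": "markdown",
"id": "dataset-peek-md",
"metadata": {},
"source": [
"- To see what pre-tokenization produces, the sketch below builds a small dataset from the first three training entries (assuming `train_data`, `tokenizer`, `format_input`, and `InstructionDataset` from the cells above) and decodes one example. Because GPT-2's byte-level BPE tokenizer is lossless, decoding should return exactly the prompt-plus-response text that was encoded"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dataset-peek",
"metadata": {},
"outputs": [],
"source": [
"# Build a small dataset from the first three training entries\n",
"sample_dataset = InstructionDataset(train_data[:3], tokenizer)\n",
"sample_ids = sample_dataset[0]\n",
"print(\"Number of examples:\", len(sample_dataset))\n",
"print(\"Number of tokens in the first example:\", len(sample_ids))\n",
"\n",
"# Decoding should reproduce the formatted prompt plus the response\n",
"expected_text = format_input(train_data[0]) + f\"\\n\\n### Response:\\n{train_data[0]['output']}\"\n",
"assert tokenizer.decode(sample_ids) == expected_text\n",
"print(tokenizer.decode(sample_ids))"
]
},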
{
"cell_type": "code",
"execution_count": 12,
"id": "ff24fe1a-5746-461c-ad3d-b6d84a1a7c96",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ff24fe1a-5746-461c-ad3d-b6d84a1a7c96",
"outputId": "7459dd6d-aaad-49c5-9c82-db9b50358c77"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[50256]\n"
]
}
],
"source": [
"print(tokenizer.encode(\"<|endoftext|>\", allowed_special={\"<|endoftext|>\"}))"
]
},
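{
"cell_type": "markdown",
"id": "endoftext-check-md",
"metadata": {},
"source": [
"- Token ID 50256, which `custom_collate_fn` below uses as its default padding token, is GPT-2's `<|endoftext|>` token. A small check, assuming the `tokenizer` from above: by default, tiktoken raises a `ValueError` when asked to encode special-token text that is not listed in `allowed_special`, and decoding ID 50256 returns the special-token string"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "endoftext-check",
"metadata": {},
"outputs": [],
"source": [
"# By default, tiktoken refuses to encode text that contains a special token\n",
"try:\n",
"    tokenizer.encode(\"<|endoftext|>\")\n",
"except ValueError as err:\n",
"    print(\"Encoding without allowed_special raised:\", type(err).__name__)\n",
"\n",
"# Decoding the ID maps it back to the special-token string\n",
"assert tokenizer.decode([50256]) == \"<|endoftext|>\"\n",
"print(\"Decoded ID 50256:\", tokenizer.decode([50256]))"
]
},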
{
"cell_type": "code",
"execution_count": 13,
"id": "W2jvh-OP9MFV",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "W2jvh-OP9MFV",
"outputId": "b3f94569-8997-461b-909e-b469e0b3c089"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.1269)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# How ignore_index masking works: first, the mean loss over two examples\n",
"\n",
"targets = torch.tensor([0, 1])\n",
"inputs = torch.tensor(\n",
" [[-1., 1.],\n",
" [-0.5, 1.5]]\n",
")\n",
"\n",
"torch.nn.functional.cross_entropy(inputs, targets)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "nvVMuil89v9N",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nvVMuil89v9N",
"outputId": "5d9f0948-ddc2-4766-c2ba-c14ca550e9d1"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.7936)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"targets = torch.tensor([0, 1, 1])\n",
"inputs = torch.tensor(\n",
" [[-1., 1.],\n",
" [-0.5, 1.5],\n",
" [-0.5, 1.5]]\n",
")\n",
"torch.nn.functional.cross_entropy(inputs, targets)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "RTyB1vah9p56",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RTyB1vah9p56",
"outputId": "245a8257-d1a3-4e94-a062-07b820b71aed"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.1269)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"targets = torch.tensor([0, 1, -100])\n",
"inputs = torch.tensor(\n",
" [[-1., 1.],\n",
" [-0.5, 1.5],\n",
" [-0.5, 1.5]]\n",
")\n",
"torch.nn.functional.cross_entropy(inputs, targets)"
]
},
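{
"cell_type": "markdown",
"id": "ignore-index-check-md",
"metadata": {},
"source": [
"- The masked loss above equals the two-example loss from the first cell because `torch.nn.functional.cross_entropy` uses `ignore_index=-100` by default: targets equal to -100 are excluded from both the sum and the count when the loss is averaged. The self-contained check below reproduces this by averaging the per-example losses of the non-ignored rows by hand"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ignore-index-check",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"example_logits = torch.tensor(\n",
"    [[-1.0, 1.0],\n",
"     [-0.5, 1.5],\n",
"     [-0.5, 1.5]]\n",
")\n",
"example_targets = torch.tensor([0, 1, -100])\n",
"\n",
"# Built-in masking through the default ignore_index=-100\n",
"loss_masked = F.cross_entropy(example_logits, example_targets)\n",
"\n",
"# Manual version: per-example losses of the non-ignored rows, then their mean\n",
"per_example = F.cross_entropy(example_logits[:2], example_targets[:2], reduction=\"none\")\n",
"loss_manual = per_example.mean()\n",
"\n",
"print(loss_masked, loss_manual)\n",
"assert torch.allclose(loss_masked, loss_manual)"
]
},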
{
"cell_type": "code",
"execution_count": 16,
"id": "41ec6e2d-9eb2-4124-913e-d2af39be4cf2",
"metadata": {
"id": "41ec6e2d-9eb2-4124-913e-d2af39be4cf2"
},
"outputs": [],
"source": [
"def custom_collate_fn(\n",
"    batch,\n",
"    pad_token_id=50256,\n",
"    ignore_index=-100,\n",
"    allowed_max_length=None,\n",
"    device=\"cpu\"\n",
"):\n",
"    # Find the longest sequence in the batch (+1 for the <|endoftext|> token added below)\n",
"    batch_max_length = max(len(item)+1 for item in batch)\n",
"\n",
"    # Pad and prepare inputs and targets\n",
"    inputs_lst, targets_lst = [], []\n",
"    for item in batch:\n",
"        # Copy the list so the token IDs stored in the dataset are not modified in place\n",
"        new_item = item.copy()\n",
"        # Add an <|endoftext|> token\n",
"        new_item += [pad_token_id]\n",
"        # Pad sequences to batch_max_length\n",
"        padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))\n",
"        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs\n",
"        targets = torch.tensor(padded[1:])  # Shift by one position for next-token targets\n",
"\n",
"        # Replace all but the first padding tokens in targets by ignore_index\n",
"        mask = targets == pad_token_id\n",
"        indices = torch.nonzero(mask).squeeze()\n",
"        if indices.numel() > 1:\n",
"            targets[indices[1:]] = ignore_index\n",
"\n",
"        # Optionally truncate to maximum sequence length\n",
"        if allowed_max_length is not None:\n",
"            inputs = inputs[:allowed_max_length]\n",
"            targets = targets[:allowed_max_length]\n",
"\n",
"        inputs_lst.append(inputs)\n",
"        targets_lst.append(targets)\n",
"\n",
"    inputs_tensor = torch.stack(inputs_lst).to(device)\n",
"    targets_tensor = torch.stack(targets_lst).to(device)\n",
"\n",
"    return inputs_tensor, targets_tensor"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "cdf5eec4-9ebe-4be0-9fca-9a47bee88fdc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cdf5eec4-9ebe-4be0-9fca-9a47bee88fdc",
"outputId": "0484b12b-b0d6-4329-d6d3-7a2b05fbaf8e"
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 0, 1, 2, 3, 4, 5, 6],\n",
" [ 7, 8, 9, 50256, 50256, 50256, 50256]]),\n",
" tensor([[ 1, 2, 3, 4, 5, 6, 50256],\n",
" [ 8, 9, 50256, -100, -100, -100, -100]]))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs_1 = [0, 1, 2, 3, 4, 5, 6]\n",
"inputs_2 = [7, 8, 9]\n",
"\n",
"batch = (\n",
" inputs_1,\n",
" inputs_2\n",
")\n",
"\n",
"custom_collate_fn(batch)"
]
},
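{
"cell_type": "markdown",
"id": "collate-truncation-check-md",
"metadata": {},
"source": [
"- The example above does not exercise `allowed_max_length`. A small check, reusing the same toy token IDs and assuming `custom_collate_fn` from above: truncation is applied after the masking step, so inputs and targets are cut to the same length"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "collate-truncation-check",
"metadata": {},
"outputs": [],
"source": [
"# Same toy token IDs as above; truncate inputs and targets to 4 positions\n",
"toy_batch = ([0, 1, 2, 3, 4, 5, 6], [7, 8, 9])\n",
"toy_inputs, toy_targets = custom_collate_fn(toy_batch, allowed_max_length=4)\n",
"\n",
"print(toy_inputs)\n",
"print(toy_targets)\n",
"assert toy_inputs.shape == toy_targets.shape == (2, 4)\n",
"assert toy_targets[1].tolist() == [8, 9, 50256, -100]"
]
},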
{
"cell_type": "code",
"execution_count": 18,
"id": "etpqqWh8phKc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "etpqqWh8phKc",
"outputId": "f2f902d2-d51a-4a62-a2ae-b1f52037c92f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device: cuda\n"
]
}
],
"source": [
"from functools import partial\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"Device:\", device)\n",
"\n",
"customized_collate_fn = partial(custom_collate_fn, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "BtWkgir6Hlpe",
"metadata": {
"id": "BtWkgir6Hlpe"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"num_workers = 0\n",
"batch_size = 8\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"train_dataset = InstructionDataset(train_data, tokenizer)\n",
"train_loader = DataLoader(\n",
"    train_dataset,\n",
"    batch_size=batch_size,\n",
"    collate_fn=customized_collate_fn,\n",
"    shuffle=True,\n",
"    drop_last=True,\n",
"    num_workers=num_workers\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "1d097dc8-ad34-4f05-b435-e4147965f532",
"metadata": {
"id": "1d097dc8-ad34-4f05-b435-e4147965f532"
},
"outputs": [],
"source": [
"val_dataset = InstructionDataset(val_data, tokenizer)\n",
"val_loader = DataLoader(\n",
"    val_dataset,\n",
"    batch_size=batch_size,\n",
"    collate_fn=customized_collate_fn,\n",
"    shuffle=False,\n",
"    drop_last=False,\n",
"    num_workers=num_workers\n",
")\n",
"\n",
"test_dataset = InstructionDataset(test_data, tokenizer)\n",
"test_loader = DataLoader(\n",
"    test_dataset,\n",
"    batch_size=batch_size,\n",
"    collate_fn=customized_collate_fn,\n",
"    shuffle=False,\n",
"    drop_last=False,\n",
"    num_workers=num_workers\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "GGs1AI3vHpnX",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GGs1AI3vHpnX",
"outputId": "df95971c-10ca-49e8-9823-d63bc5b6a3fc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loader:\n",
"torch.Size([8, 61]) torch.Size([8, 61])\n",
"torch.Size([8, 76]) torch.Size([8, 76])\n",
"torch.Size([8, 73]) torch.Size([8, 73])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 65]) torch.Size([8, 65])\n",
"torch.Size([8, 72]) torch.Size([8, 72])\n",
"torch.Size([8, 80]) torch.Size([8, 80])\n",
"torch.Size([8, 67]) torch.Size([8, 67])\n",
"torch.Size([8, 62]) torch.Size([8, 62])\n",
"torch.Size([8, 75]) torch.Size([8, 75])\n",
"torch.Size([8, 62]) torch.Size([8, 62])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 67]) torch.Size([8, 67])\n",
"torch.Size([8, 77]) torch.Size([8, 77])\n",
"torch.Size([8, 69]) torch.Size([8, 69])\n",
"torch.Size([8, 79]) torch.Size([8, 79])\n",
"torch.Size([8, 71]) torch.Size([8, 71])\n",
"torch.Size([8, 66]) torch.Size([8, 66])\n",
"torch.Size([8, 83]) torch.Size([8, 83])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 80]) torch.Size([8, 80])\n",
"torch.Size([8, 71]) torch.Size([8, 71])\n",
"torch.Size([8, 69]) torch.Size([8, 69])\n",
"torch.Size([8, 65]) torch.Size([8, 65])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 60]) torch.Size([8, 60])\n",
"torch.Size([8, 59]) torch.Size([8, 59])\n",
"torch.Size([8, 69]) torch.Size([8, 69])\n",
"torch.Size([8, 63]) torch.Size([8, 63])\n",
"torch.Size([8, 65]) torch.Size([8, 65])\n",
"torch.Size([8, 76]) torch.Size([8, 76])\n",
"torch.Size([8, 66]) torch.Size([8, 66])\n",
"torch.Size([8, 71]) torch.Size([8, 71])\n",
"torch.Size([8, 91]) torch.Size([8, 91])\n",
"torch.Size([8, 65]) torch.Size([8, 65])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 67]) torch.Size([8, 67])\n",
"torch.Size([8, 66]) torch.Size([8, 66])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 65]) torch.Size([8, 65])\n",
"torch.Size([8, 75]) torch.Size([8, 75])\n",
"torch.Size([8, 89]) torch.Size([8, 89])\n",
"torch.Size([8, 59]) torch.Size([8, 59])\n",
"torch.Size([8, 88]) torch.Size([8, 88])\n",
"torch.Size([8, 83]) torch.Size([8, 83])\n",
"torch.Size([8, 83]) torch.Size([8, 83])\n",
"torch.Size([8, 70]) torch.Size([8, 70])\n",
"torch.Size([8, 65]) torch.Size([8, 65])\n",
"torch.Size([8, 74]) torch.Size([8, 74])\n",
"torch.Size([8, 76]) torch.Size([8, 76])\n",
"torch.Size([8, 67]) torch.Size([8, 67])\n",
"torch.Size([8, 75]) torch.Size([8, 75])\n",
"torch.Size([8, 83]) torch.Size([8, 83])\n",
"torch.Size([8, 69]) torch.Size([8, 69])\n",
"torch.Size([8, 67]) torch.Size([8, 67])\n",
"torch.Size([8, 60]) torch.Size([8, 60])\n",
"torch.Size([8, 60]) torch.Size([8, 60])\n",
"torch.Size([8, 66]) torch.Size([8, 66])\n",
"torch.Size([8, 80]) torch.Size([8, 80])\n",
"torch.Size([8, 71]) torch.Size([8, 71])\n",
"torch.Size([8, 61]) torch.Size([8, 61])\n",
"torch.Size([8, 58]) torch.Size([8, 58])\n",
"torch.Size([8, 71]) torch.Size([8, 71])\n",
"torch.Size([8, 67]) torch.Size([8, 67])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 63]) torch.Size([8, 63])\n",
"torch.Size([8, 87]) torch.Size([8, 87])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 71]) torch.Size([8, 71])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 71]) torch.Size([8, 71])\n",
"torch.Size([8, 61]) torch.Size([8, 61])\n",
"torch.Size([8, 65]) torch.Size([8, 65])\n",
"torch.Size([8, 67]) torch.Size([8, 67])\n",
"torch.Size([8, 65]) torch.Size([8, 65])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 60]) torch.Size([8, 60])\n",
"torch.Size([8, 72]) torch.Size([8, 72])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 70]) torch.Size([8, 70])\n",
"torch.Size([8, 57]) torch.Size([8, 57])\n",
"torch.Size([8, 72]) torch.Size([8, 72])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 62]) torch.Size([8, 62])\n",
"torch.Size([8, 74]) torch.Size([8, 74])\n",
"torch.Size([8, 80]) torch.Size([8, 80])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 70]) torch.Size([8, 70])\n",
"torch.Size([8, 91]) torch.Size([8, 91])\n",
"torch.Size([8, 61]) torch.Size([8, 61])\n",
"torch.Size([8, 66]) torch.Size([8, 66])\n",
"torch.Size([8, 80]) torch.Size([8, 80])\n",
"torch.Size([8, 81]) torch.Size([8, 81])\n",
"torch.Size([8, 74]) torch.Size([8, 74])\n",
"torch.Size([8, 82]) torch.Size([8, 82])\n",
"torch.Size([8, 63]) torch.Size([8, 63])\n",
"torch.Size([8, 83]) torch.Size([8, 83])\n",
"torch.Size([8, 68]) torch.Size([8, 68])\n",
"torch.Size([8, 67]) torch.Size([8, 67])\n",
"torch.Size([8, 77]) torch.Size([8, 77])\n",
"torch.Size([8, 91]) torch.Size([8, 91])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 61]) torch.Size([8, 61])\n",
"torch.Size([8, 75]) torch.Size([8, 75])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 66]) torch.Size([8, 66])\n",
"torch.Size([8, 78]) torch.Size([8, 78])\n",
"torch.Size([8, 66]) torch.Size([8, 66])\n",
"torch.Size([8, 64]) torch.Size([8, 64])\n",
"torch.Size([8, 83]) torch.Size([8, 83])\n",
"torch.Size([8, 66]) torch.Size([8, 66])\n",
"torch.Size([8, 74]) torch.Size([8, 74])\n",
"torch.Size([8, 69]) torch.Size([8, 69])\n"
]
}
],
"source": [
"print(\"Train loader:\")\n",
"for x, y in train_loader:\n",
" print(x.shape, y.shape)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "21b8fd02-014f-4481-9b71-5bfee8f9dfcd",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "21b8fd02-014f-4481-9b71-5bfee8f9dfcd",
"outputId": "cacf7f22-ec66-4350-8db4-890e7e86718f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 30003, 6525, 262, 6827, 1262, 257, 985, 576, 13, 198, 198, 21017, 23412, 25, 198, 464, 5156, 318, 845, 13779, 13, 198, 198, 21017, 18261, 25, 198, 464, 5156, 318, 355, 13779, 355, 257, 4936, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, "
]
}
],
"source": [
"for i in x[0]:\n",
" print(i.item(), end=\", \")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "51649ab4-1a7e-4a9e-92c5-950a24fde211",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "51649ab4-1a7e-4a9e-92c5-950a24fde211",
"outputId": "486fda24-80d4-4bc2-f253-2476f93cd146"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 30003, 6525, 262, 6827, 1262, 257, 985, 576, 13, 198, 198, 21017, 23412, 25, 198, 464, 5156, 318, 845, 13779, 13, 198, 198, 21017, 18261, 25, 198, 464, 5156, 318, 355, 13779, 355, 257, 4936, 13, 50256, -100, -100, -100, -100, -100, -100, -100, -100, -100, "
]
}
],
"source": [
"for i in y[0]:\n",
" print(i.item(), end=\", \")"
]
},
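{
"cell_type": "markdown",
"id": "decode-targets-md",
"metadata": {},
"source": [
"- To relate the target IDs back to text, the sketch below (assuming `y` and `tokenizer` from the cells above) drops the `-100` positions from the first target sequence and decodes the rest. The result is the training text shifted by one token, ending in a single `<|endoftext|>` token"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "decode-targets",
"metadata": {},
"outputs": [],
"source": [
"# Drop the ignore_index (-100) positions, then decode the remaining target IDs\n",
"target_ids = [t for t in y[0].tolist() if t != -100]\n",
"print(tokenizer.decode(target_ids))\n",
"\n",
"# This target row keeps exactly one end-of-text token (ID 50256)\n",
"assert target_ids.count(50256) == 1"
]
},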
{
"cell_type": "markdown",
"id": "d6aad445-8f19-4238-b9bf-db80767fb91a",
"metadata": {
"id": "d6aad445-8f19-4238-b9bf-db80767fb91a"
},
"source": [
"## 7.4 Loading a pretrained LLM"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0d249d67-5eba-414e-9bd2-972ebf01329d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0d249d67-5eba-414e-9bd2-972ebf01329d",
"outputId": "ca78e098-c253-4bbe-ebb5-6fd018d8e037"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"checkpoint: 100%|██████████| 77.0/77.0 [00:00<00:00, 116kiB/s]\n",
"encoder.json: 100%|██████████| 1.04M/1.04M [00:02<00:00, 509kiB/s]\n",
"hparams.json: 100%|██████████| 91.0/91.0 [00:00<00:00, 138kiB/s]\n",
"model.ckpt.data-00000-of-00001: 100%|██████████| 1.42G/1.42G [02:49<00:00, 8.38MiB/s]\n",
"model.ckpt.index: 100%|██████████| 10.4k/10.4k [00:00<00:00, 13.8MiB/s]\n",
"model.ckpt.meta: 100%|██████████| 927k/927k [00:02<00:00, 454kiB/s]\n",
"vocab.bpe: 100%|██████████| 456k/456k [00:01<00:00, 321kiB/s]\n"
]
}
],
"source": [
"from gpt_download import download_and_load_gpt2\n",
"from previous_chapters import GPTModel, load_weights_into_gpt\n",
"\n",
"\n",
"BASE_CONFIG = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 1024, # Context length\n",
" \"drop_rate\": 0.0, # Dropout rate\n",
" \"qkv_bias\": True # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
" \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
" \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
" \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"CHOOSE_MODEL = \"gpt2-medium (355M)\"\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
"\n",
"model_size = CHOOSE_MODEL.split(\" \")[-1].lstrip(\"(\").rstrip(\")\")\n",
"settings, params = download_and_load_gpt2(model_size=model_size, models_dir=\"gpt2\")\n",
"\n",
"model = GPTModel(BASE_CONFIG)\n",
"load_weights_into_gpt(model, params)\n",
"model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "7bd32b7c-5b44-4d25-a09f-46836802ca74",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7bd32b7c-5b44-4d25-a09f-46836802ca74",
"outputId": "e5dbf217-591c-4c2e-9ec2-ef5365fa269e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Convert the active sentence to passive: 'The chef cooks the meal every day.'\n",
"\n",
"### Response:\n",
"\n",
"The chef cooks the meal every day.\n",
"\n",
"### Instruction:\n",
"\n",
"Convert the active sentence to passive: 'The chef cooks the\n"
]
}
],
"source": [
"from previous_chapters import (\n",
" generate,\n",
" text_to_token_ids,\n",
" token_ids_to_text\n",
")\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"token_ids = generate(\n",
" model=model,\n",
" idx=text_to_token_ids(format_input(val_data[0]), tokenizer),\n",
" max_new_tokens=35,\n",
" context_size=BASE_CONFIG[\"context_length\"],\n",
")\n",
"\n",
"print(token_ids_to_text(token_ids, tokenizer))"
]
},
{
"cell_type": "markdown",
"id": "70d27b9d-a942-4cf5-b797-848c5f01e723",
"metadata": {
"id": "70d27b9d-a942-4cf5-b797-848c5f01e723"
},
"source": [
"## 7.5 Finetuning the LLM on instruction data"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "65444865-df87-4d98-9faf-875e1c4be860",
"metadata": {
"id": "65444865-df87-4d98-9faf-875e1c4be860"
},
"outputs": [],
"source": [
"from previous_chapters import (\n",
" calc_loss_loader,\n",
" train_model_simple\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "d99fc6f8-63b2-43da-adbb-a7b6b92c8dd5",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d99fc6f8-63b2-43da-adbb-a7b6b92c8dd5",
"outputId": "a4d82a24-f16e-4cf7-ebe6-0bff051517a1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training loss: 3.8259091854095457\n",
"Validation loss: 3.7619335651397705\n"
]
}
],
"source": [
"model.to(device)\n",
"\n",
"torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader\n",
"\n",
"with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet\n",
" train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)\n",
" val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)\n",
"\n",
"print(\"Training loss:\", train_loss)\n",
"print(\"Validation loss:\", val_loss)"
]
},
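{
"cell_type": "markdown",
"id": "sketch-ignore-index-loss",
"metadata": {},
"source": [
"- The losses above are averages of token-level cross-entropy; target positions set to `-100`, such as the padding in the target batch printed earlier, do not contribute to them\n",
"- The pure-Python sketch below illustrates this masking under the assumption that `calc_loss_loader` relies on `torch.nn.functional.cross_entropy`, whose defaults are `ignore_index=-100` and mean reduction; `cross_entropy_ignore_index` is a hypothetical helper for illustration, not part of `previous_chapters`\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def cross_entropy_ignore_index(logits, targets, ignore_index=-100):\n",
"    # Mean negative log-likelihood over the positions whose target is not\n",
"    # ignore_index (mirroring the defaults of torch.nn.functional.cross_entropy)\n",
"    total, count = 0.0, 0\n",
"    for row, target in zip(logits, targets):\n",
"        if target == ignore_index:\n",
"            continue  # padding position: contributes nothing to the loss\n",
"        m = max(row)  # subtract the max for a numerically stable log-sum-exp\n",
"        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))\n",
"        total += log_sum_exp - row[target]  # -log softmax(row)[target]\n",
"        count += 1\n",
"    return total / count if count else float('nan')\n",
"\n",
"\n",
"logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.0, 0.0, 0.0]]\n",
"targets = [0, 1, -100]  # the last position is padding\n",
"print(cross_entropy_ignore_index(logits, targets))\n",
"print(cross_entropy_ignore_index(logits[:2], targets[:2]))  # same value\n",
"```\n",
"\n",
"- Removing the padded position leaves the loss unchanged, so the `-100` entries let sequences of different lengths share a batch without affecting the loss"
]
},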
{
"cell_type": "markdown",
"id": "db4b57fb-e689-4550-931c-6d34a932487c",
"metadata": {
"id": "db4b57fb-e689-4550-931c-6d34a932487c"
},
"source": [
"- Runtimes:\n",
"\n",
"<div style=\"text-align: left;\">\n",
" \n",
"| Model | Platform | Runtime |\n",
"|--------------------|-----------------------|----------------|\n",
"| gpt2-medium (355M) | CPU (M3 MacBook Air) | 23.67 minutes |\n",
"| gpt2-medium (355M) | GPU (L4) | 2.98 minutes |\n",
"| gpt2-medium (355M) | GPU (A100) | 1.29 minutes |\n",
"| gpt2-small (124M) | CPU (M3 MacBook Air) | 8.61 minutes |\n",
"| gpt2-small (124M) | GPU (A100) | 0.59 minutes |\n",
"\n",
"</div>\n",
"\n",
"- This notebook was run with the `\"gpt2-medium (355M)\"` model"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "78bcf83a-1fff-4540-97c1-765c4016d5e3",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "78bcf83a-1fff-4540-97c1-765c4016d5e3",
"outputId": "285ca27c-019f-4c2b-e130-8c46d2e7df53"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ep 1 (Step 000000): Train loss 2.637, Val loss 2.626\n",
"Ep 1 (Step 000005): Train loss 1.174, Val loss 1.102\n",
"Ep 1 (Step 000010): Train loss 0.872, Val loss 0.944\n",
"Ep 1 (Step 000015): Train loss 0.857, Val loss 0.906\n",
"Ep 1 (Step 000020): Train loss 0.776, Val loss 0.881\n",
"Ep 1 (Step 000025): Train loss 0.754, Val loss 0.859\n",
"Ep 1 (Step 000030): Train loss 0.799, Val loss 0.836\n",
"Ep 1 (Step 000035): Train loss 0.714, Val loss 0.808\n",
"Ep 1 (Step 000040): Train loss 0.672, Val loss 0.806\n",
"Ep 1 (Step 000045): Train loss 0.633, Val loss 0.789\n",
"Ep 1 (Step 000050): Train loss 0.663, Val loss 0.783\n",
"Ep 1 (Step 000055): Train loss 0.760, Val loss 0.763\n",
"Ep 1 (Step 000060): Train loss 0.719, Val loss 0.743\n",
"Ep 1 (Step 000065): Train loss 0.653, Val loss 0.735\n",
"Ep 1 (Step 000070): Train loss 0.532, Val loss 0.729\n",
"Ep 1 (Step 000075): Train loss 0.569, Val loss 0.728\n",
"Ep 1 (Step 000080): Train loss 0.605, Val loss 0.725\n",
"Ep 1 (Step 000085): Train loss 0.509, Val loss 0.709\n",
"Ep 1 (Step 000090): Train loss 0.562, Val loss 0.691\n",
"Ep 1 (Step 000095): Train loss 0.500, Val loss 0.681\n",
"Ep 1 (Step 000100): Train loss 0.503, Val loss 0.677\n",
"Ep 1 (Step 000105): Train loss 0.564, Val loss 0.670\n",
"Ep 1 (Step 000110): Train loss 0.555, Val loss 0.666\n",
"Ep 1 (Step 000115): Train loss 0.508, Val loss 0.664\n",
"Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Convert the active sentence to passive: 'The chef cooks the meal every day.' ### Response: The meal is prepared every day by the chef.<|endoftext|>The following is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Convert the active sentence to passive:\n",
"Ep 2 (Step 000120): Train loss 0.435, Val loss 0.672\n",
"Ep 2 (Step 000125): Train loss 0.451, Val loss 0.687\n",
"Ep 2 (Step 000130): Train loss 0.447, Val loss 0.683\n",
"Ep 2 (Step 000135): Train loss 0.405, Val loss 0.682\n",
"Ep 2 (Step 000140): Train loss 0.409, Val loss 0.681\n",
"Ep 2 (Step 000145): Train loss 0.369, Val loss 0.680\n",
"Ep 2 (Step 000150): Train loss 0.382, Val loss 0.675\n",
"Ep 2 (Step 000155): Train loss 0.413, Val loss 0.675\n",
"Ep 2 (Step 000160): Train loss 0.415, Val loss 0.683\n",
"Ep 2 (Step 000165): Train loss 0.379, Val loss 0.686\n",
"Ep 2 (Step 000170): Train loss 0.323, Val loss 0.681\n",
"Ep 2 (Step 000175): Train loss 0.337, Val loss 0.669\n",
"Ep 2 (Step 000180): Train loss 0.392, Val loss 0.657\n",
"Ep 2 (Step 000185): Train loss 0.415, Val loss 0.657\n",
"Ep 2 (Step 000190): Train loss 0.340, Val loss 0.648\n",
"Ep 2 (Step 000195): Train loss 0.329, Val loss 0.635\n",
"Ep 2 (Step 000200): Train loss 0.310, Val loss 0.635\n",
"Ep 2 (Step 000205): Train loss 0.352, Val loss 0.631\n",
"Ep 2 (Step 000210): Train loss 0.367, Val loss 0.630\n",
"Ep 2 (Step 000215): Train loss 0.396, Val loss 0.634\n",
"Ep 2 (Step 000220): Train loss 0.300, Val loss 0.647\n",
"Ep 2 (Step 000225): Train loss 0.347, Val loss 0.660\n",
"Ep 2 (Step 000230): Train loss 0.294, Val loss 0.655\n",
"Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Convert the active sentence to passive: 'The chef cooks the meal every day.' ### Response: The meal is cooked every day by the chef.<|endoftext|>The following is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: What is the capital of the United Kingdom\n",
"Ep 3 (Step 000235): Train loss 0.328, Val loss 0.661\n",
"Ep 3 (Step 000240): Train loss 0.280, Val loss 0.692\n",
"Ep 3 (Step 000245): Train loss 0.274, Val loss 0.702\n",
"Ep 3 (Step 000250): Train loss 0.248, Val loss 0.691\n",
"Ep 3 (Step 000255): Train loss 0.275, Val loss 0.680\n",
"Ep 3 (Step 000260): Train loss 0.266, Val loss 0.683\n",
"Ep 3 (Step 000265): Train loss 0.274, Val loss 0.701\n",
"Ep 3 (Step 000270): Train loss 0.280, Val loss 0.715\n",
"Ep 3 (Step 000275): Train loss 0.276, Val loss 0.705\n",
"Ep 3 (Step 000280): Train loss 0.296, Val loss 0.710\n",
"Ep 3 (Step 000285): Train loss 0.294, Val loss 0.714\n",
"Ep 3 (Step 000290): Train loss 0.287, Val loss 0.717\n",
"Ep 3 (Step 000295): Train loss 0.267, Val loss 0.711\n",
"Ep 3 (Step 000300): Train loss 0.271, Val loss 0.694\n",
"Ep 3 (Step 000305): Train loss 0.277, Val loss 0.686\n",
"Ep 3 (Step 000310): Train loss 0.276, Val loss 0.689\n",
"Ep 3 (Step 000315): Train loss 0.238, Val loss 0.688\n",
"Ep 3 (Step 000320): Train loss 0.255, Val loss 0.691\n",
"Ep 3 (Step 000325): Train loss 0.235, Val loss 0.693\n",
"Ep 3 (Step 000330): Train loss 0.233, Val loss 0.696\n",
"Ep 3 (Step 000335): Train loss 0.224, Val loss 0.698\n",
"Ep 3 (Step 000340): Train loss 0.243, Val loss 0.687\n",
"Ep 3 (Step 000345): Train loss 0.244, Val loss 0.675\n",
"Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Convert the active sentence to passive: 'The chef cooks the meal every day.' ### Response: The chef cooks the meal every day.<|endoftext|>The following is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: What is the capital of the United Kingdom? \n",
"Training completed in 2.98 minutes.\n"
]
}
],
"source": [
"import time\n",
"\n",
"start_time = time.time()\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)\n",
"\n",
"num_epochs = 3\n",
"\n",
"train_losses, val_losses, tokens_seen = train_model_simple(\n",
" model, train_loader, val_loader, optimizer, device,\n",
" num_epochs=num_epochs, eval_freq=5, eval_iter=5,\n",
" start_context=format_input(val_data[0]), tokenizer=tokenizer\n",
")\n",
"\n",
"end_time = time.time()\n",
"execution_time_minutes = (end_time - start_time) / 60\n",
"print(f\"Training completed in {execution_time_minutes:.2f} minutes.\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "1Vdh7jmHI1we",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 308
},
"id": "1Vdh7jmHI1we",
"outputId": "475faf7f-13e6-4168-84f2-3eb3897ffd73"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAegAAAEiCAYAAAAyI0HeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABhEklEQVR4nO3deVwU9f8H8Nfuwl7Act+XoIiIyKUQYGpJopmJZZr5TS2PLMz8WWZ+K89vaWlqpallSqWpaV55452ItygqoiKCB4dynwu7+/n9MbK4AnIt7ILv5+MxD3ZnPjPznll23vuZz8xneIwxBkIIIYToFb6uAyCEEEJIdZSgCSGEED1ECZoQQgjRQ5SgCSGEED1ECZoQQgjRQ5SgCSGEED1ECZoQQgjRQ5SgCSGEED1ECZoQQgjRQ5SgCWkDbt++DR6Ph/j4eF2HQgjREkrQhOgJHo/31GHWrFm6DpEQ0oIMdB0AIYSTnp6ufr1x40bMmDEDSUlJ6nHGxsa6CIsQoiNUgyZET9jZ2akHU1NT8Hg89XsbGxssWrQITk5OEIlE8PPzw969e2tdllKpxLvvvotOnTohLS0NALB9+3YEBARALBbD3d0ds2fPhkKhUM/D4/GwatUqDB48GFKpFB4eHtixY4d6em5uLkaMGAFra2tIJBJ4eHhgzZo1tcawefNm+Pj4QCKRwNLSEuHh4SguLlZPX7VqFby8vCAWi9GpUyf89NNPGvPfuXMHQ4cOhZmZGSwsLDBo0CDcvn1bPX306NGIjIzEwoULYW9vD0tLS0RFRaGioqLe+5wQvcYIIXpnzZo1zNTUVP1+0aJFTCaTsfXr17Nr166xTz/9lBkaGrLr168zxhhLSUlhANiFCxdYWVkZGzx4MPP392dZWVmMMcaOHTvGZDIZi46OZsnJyWz//v2sXbt2bNasWep1AGBOTk7szz//ZDdu3GCTJk1ixsbGLDs7mzHGWFRUFPPz82NnzpxhKSkpLCYmhu3YsaPG+O/fv88MDAzYokWLWEpKCrt06RJbtmwZKywsZIwxtnbtWmZvb8/+/vtvduvWLfb3338zCwsLFh0dzRhjrLy8nHl5ebF3332XXbp0iV29epW99dZbzNPTk8nlcsYYY6NGjWIymYxNmDCBJSYmsn/++YdJpVL2888/a/fDIERHKEETooeeTNAODg7sq6++0ijTvXt39sEHHzDGqhL0v//+y/r06cN69OjB8vLy1GX79OnDvv76a435//jjD2Zvb69+D4B98cUX6vdFRUUMANuzZw9jjLGBAweyd955p17xnzt3jgFgt2/frnF6+/bt2Z9//qkxbu7cuSwkJEQdm6enJ1OpVOrpcrmcSSQStm/fPsYYl6BdXV2ZQqFQl3njjTfYsGHD6hUjIfqO2qAJ0XMFBQW4f/8+wsLCNMaHhYXh4sWLGuOGDx8OJycnHDp0CBKJRD3+4sWLiI2NxVdffaUep1QqUVZWhpKSEkilUgBA165d1dONjIwgk8mQlZUFAHj//ffx+uuv4/z58+jbty8iIyMRGhpaY8y+vr7o06cPfHx8EBERgb59+2LIkCEwNzdHcXExkpOTMWbMGIwbN049j0KhgKmpqTremzdvwsTERGO5ZWVlSE5OVr/39vaGQCBQv7e3t0dCQsJT9iYhrQclaELakJdffhlr165FXFwcXnzxRfX4oqIizJ49G6+99lq1ecRisfq1oaGhxjQejweVSgUA6N+/P1JTU7F7927ExMSgT58+iIqKwsKFC6stUyAQICYmBidOnMD+/fvx448/4vPPP8epU6fUPwZ++eUXBAcHV5uvMt7AwECsW7eu2rKtra3rFS8hrR0laEL0nEwmg4ODA2JjY9GrVy/1+NjYWAQFBWmUff/999GlSxe8+uqr2LVrl7p8QEAAkpKS0KFDhybFYm1tjVGjRmHUqFF4/vnnMXXq1BoTNMAly7CwMISFhWHGjBlwdXXF1q1bMWXKFD
g4OODWrVsYMWJEjfMGBARg48aNsLGxgUwma1LMhLRWlKAJaQWmTp2KmTNnon379vDz88OaNWsQHx9fYw3zww8/hFKpxCuvvII9e/agR48emDFjBl555RW4uLhgyJAh4PP5uHjxIi5fvoz//e9/9YphxowZCAwMhLe3N+RyOXbu3AkvL68ay546dQoHDx5E3759YWNjg1OnTuHBgwfq8rNnz8akSZNgamqKfv36QS6X4+zZs8jNzcWUKVMwYsQILFiwAIMGDcKcOXPg5OSE1NRUbNmyBZ9++imcnJwavzMJaSUoQRPSCkyaNAn5+fn4+OOPkZWVhc6dO2PHjh3w8PCosfzkyZOhUqnw8ssvY+/evYiIiMDOnTsxZ84cfPPNNzA0NESnTp0wduzYescgFAoxffp03L59GxKJBM8//zw2bNhQY1mZTIZjx45hyZIlKCgogKurK7777jv0798fADB27FhIpVIsWLAAU6dOhZGREXx8fDB58mQAgFQqxbFjxzBt2jS89tprKCwshKOjI/r06UM1avLM4DHGmK6DIIQQQogm6qiEEEII0UOUoAkhhBA9RAmaEEII0UOUoAkhhBA9RAmaEEII0UOUoAkhhBA9RAm6kZYtW4Z27dpBLBYjODgYp0+fbrZ1zZs3D927d4eJiQlsbGwQGRmp8ZxgAOjduzd4PJ7GMGHCBI0yaWlpGDBgAKRSKWxsbDB16lSNxw0CwJEjRxAQEACRSIQOHTogOjq6WjwN2fZZs2ZVi6tTp07q6WVlZYiKioKlpSWMjY3x+uuvIzMzU+dxt2vXrlrcPB4PUVFRAPRrfx87dgwDBw6Eg4MDeDwetm3bpjGdMYYZM2bA3t4eEokE4eHhuHHjhkaZnJwcjBgxAjKZDGZmZhgzZgyKioo0yly6dAnPP/88xGIxnJ2d8e2331aLddOmTejUqRPEYjF8fHywe/fup8YSGBiIF198scbYKyoqMG3aNPj4+MDIyAgODg4YOXIk7t+/r7HMmj6r+fPnN2vsIpEI1tbWsLW1rXGfjx49ulpM/fr10/t9DqDG/3sej4cFCxbodJ8bGhpCJpPB2Ni41uOgPh1P6hNLnXT4oI5Wa8OGDUwoFLLVq1ezK1eusHHjxjEzMzOWmZnZLOuLiIhga9asYZcvX2bx8fHs5ZdfZi4uLqyoqEhdplevXmzcuHEsPT1dPeTn56unKxQK1qVLFxYeHs4uXLjAdu/ezaysrNj06dPVZW7dusWkUimbMmUKu3r1Kvvxxx+ZQCBge/fubfS2z5w5k3l7e2vE9eDBA/X0CRMmMGdnZ3bw4EF29uxZ9txzz7HQ0FCdx52VlaURc0xMDAPADh8+rHf7e/fu3ezzzz9nW7ZsYQDY1q1bNbZl/vz5zNTUlG3bto1dvHiRvfrqq8zNzY2Vlpaqy/Tr14/5+vqykydPsn///Zd16NCBDR8+XD09Pz+f2drashEjRrDLly+z9evXM4lEwlauXKkuExsbywQCAfv222/Z1atX2RdffMEMDQ1ZQkJCrbEEBwczU1NTtmHDhmqx5+XlsfDwcLZx40Z27do1FhcXx4KCglhgYKDG9rm6urI5c+ZofBaPfzeaI/Zly5YxDw8PZmNjU+M+HzVqFOvXr59GTDk5ORpl9HGfM8Y0Yk5PT2erV69mPB6PJScn63Sfh4aGMj8/P+bo6MhOnTpV43FQn44ndcVSH5SgGyEoKIhFRUWp3yuVSubg4MDmzZvXIuvPyspiANjRo0fV43r16sU++uijWufZvXs34/P5LCMjQz1u+fLlTCaTqZ+v++mnnzJvb2+N+YYNG8YiIiLU7xu67TNnzmS+vr41TsvLy2OGhoZs06ZN6nGJiYkMAIuLi9Np3E/66KOPWPv27dWPP9TX/f3kAVelUjE7Ozu2YMEC9bi8vDwmEonY+vXrGWOMXb16lQFgZ86cUZfZs2cP4/F47N69e4wxxn766Sdmbm6ujp0xxqZNm8Y8PT3V74cOHc
oGDBigEU9wcDB777336hVLTcniSadPn2YAWGpqqnqcq6srW7x4ca3zNHfstSXoQYMG1RpTa9rngwYNYi+++KLGOF3
2024-06-09 10:35:26 -05:00
"text/plain": [
"<Figure size 500x300 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from previous_chapters import plot_losses\n",
"\n",
"epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
"plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)"
]
},
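{
"cell_type": "markdown",
"id": "sketch-linspace-epochs",
"metadata": {},
"source": [
"- `torch.linspace(0, num_epochs, len(train_losses))` gives each recorded loss an evenly spaced, fractional epoch position on the x-axis\n",
"- The pure-Python stand-in below makes this mapping explicit; the helper name `linspace` is hypothetical, and the 70 points assume one train/validation loss pair per logged evaluation in the training output above (steps 0, 5, ..., 345 over 3 epochs)\n",
"\n",
"```python\n",
"def linspace(start, stop, num):\n",
"    # Pure-Python stand-in for torch.linspace(start, stop, num): num evenly\n",
"    # spaced values from start to stop, with both endpoints included\n",
"    if num == 1:\n",
"        return [float(start)]\n",
"    return [start + (stop - start) * i / (num - 1) for i in range(num)]\n",
"\n",
"\n",
"epochs = linspace(0, 3, 70)\n",
"print(epochs[:3], epochs[-1])\n",
"```"
]
},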
{
"cell_type": "markdown",
"id": "87b79a47-13f9-4d1f-87b1-3339bafaf2a3",
"metadata": {
"id": "87b79a47-13f9-4d1f-87b1-3339bafaf2a3"
},
"source": [
"## 7.6 Extracting and saving responses"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "F9QyvnRipwNc",
"metadata": {
"id": "F9QyvnRipwNc"
},
"outputs": [],
"source": [
"def extract_response(response):\n",
" return response[response.find(\"\\n### Response\")+len(\"\\n### Response:\")+1:]"
]
},
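{
"cell_type": "markdown",
"id": "sketch-extract-response-safe",
"metadata": {},
"source": [
"- `extract_response` relies on `str.find`, which returns `-1` when the generated text contains no `### Response` marker; the slice then starts at index 14 (`-1 + 14 + 1`) and silently returns nearly the whole text instead of an empty answer\n",
"- A more defensive variant is sketched below; `extract_response_safe` is a hypothetical name, and it assumes the answer is whatever follows the first `### Response:` marker\n",
"\n",
"```python\n",
"def extract_response_safe(text, marker='### Response:'):\n",
"    # Return the stripped text after the first marker, or an empty string\n",
"    # when the model never produced the marker\n",
"    idx = text.find(marker)\n",
"    if idx == -1:\n",
"        return ''\n",
"    return text[idx + len(marker):].strip()\n",
"\n",
"\n",
"sample = 'Below is an instruction. ### Instruction: Name a color. ### Response: Blue.'\n",
"print(extract_response_safe(sample))             # Blue.\n",
"print(repr(extract_response_safe('no marker')))  # ''\n",
"```"
]
},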
{
"cell_type": "code",
"execution_count": 34,
"id": "VQ2NZMbfucAc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VQ2NZMbfucAc",
"outputId": "1fd28d43-3fd4-4d94-a63e-07f4a53f41b6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Rewrite the sentence using a simile.\n",
"\n",
"### Input:\n",
"The car is very fast.\n",
"\n",
"Correct response:\n",
">> The car is as fast as lightning.\n",
"\n",
"Model response:\n",
">> The car is as fast as a bullet.\n",
"-------------------------------------\n",
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"What type of cloud is typically associated with thunderstorms?\n",
"\n",
"Correct response:\n",
">> The type of cloud typically associated with thunderstorms is cumulonimbus.\n",
"\n",
"Model response:\n",
">> The type of cloud typically associated with thunderstorms is a cumulus (thin, water-filled, or gas-filled).\n",
"-------------------------------------\n",
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Name the author of 'Pride and Prejudice'.\n",
"\n",
"Correct response:\n",
">> Jane Austen.\n",
"\n",
"Model response:\n",
">> The author of 'Pride and Prejudice' is Jane Austen.\n",
"-------------------------------------\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"\n",
"for entry in test_data[:3]:\n",
"\n",
" input_text = format_input(entry)\n",
"\n",
" token_ids = generate(\n",
" model=model,\n",
" idx=text_to_token_ids(input_text, tokenizer).to(device),\n",
" max_new_tokens=256,\n",
" context_size=BASE_CONFIG[\"context_length\"],\n",
" eos_id=50256\n",
" )\n",
" response = token_ids_to_text(token_ids, tokenizer)\n",
" response_text = extract_response(response)\n",
"\n",
" print(input_text)\n",
" print(f\"\\nCorrect response:\\n>> {entry['output']}\")\n",
" print(f\"\\nModel response:\\n>> {response_text.strip()}\")\n",
" print(\"-------------------------------------\")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "-PNGKzY4snKP",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-PNGKzY4snKP",
"outputId": "3e16caff-287a-4084-ed93-fcccd68e1da7"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 110/110 [01:17<00:00, 1.42it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"for i, entry in tqdm(enumerate(test_data), total=len(test_data)):\n",
"\n",
" input_text = format_input(entry)\n",
"\n",
" token_ids = generate(\n",
" model=model,\n",
" idx=text_to_token_ids(input_text, tokenizer).to(device),\n",
" max_new_tokens=256,\n",
" context_size=BASE_CONFIG[\"context_length\"],\n",
" eos_id=50256\n",
" )\n",
" response = token_ids_to_text(token_ids, tokenizer)\n",
" response_text = extract_response(response)\n",
"\n",
" test_data[i][\"model_response\"] = response_text\n",
"\n",
"\n",
"with open(\"instruction-data-with-response.json\", \"w\") as file:\n",
" json.dump(test_data, file, indent=4) # \"indent\" for pretty-printing"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "u-AvCCMTnPSE",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "u-AvCCMTnPSE",
"outputId": "90c7f165-713e-4795-9205-f2f9b4d13313"
},
"outputs": [
{
"data": {
"text/plain": [
"{'instruction': 'Rewrite the sentence using a simile.',\n",
" 'input': 'The car is very fast.',\n",
" 'output': 'The car is as fast as lightning.',\n",
" 'model_response': 'The car is as fast as a bullet.'}"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "8cBU0iHmVfOI",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8cBU0iHmVfOI",
"outputId": "df6e862f-a6c8-4d23-ac3a-7645fd25a59d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model saved as gpt2-medium355M-sft.pth\n"
]
}
],
"source": [
"import re\n",
"\n",
"file_name = f\"{re.sub(r'[ ()]', '', CHOOSE_MODEL) }-sft.pth\"\n",
"torch.save(model.state_dict(), file_name)\n",
"print(f\"Model saved as {file_name}\")"
]
},
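{
"cell_type": "markdown",
"id": "sketch-checkpoint-name",
"metadata": {},
"source": [
"- The character class `[ ()]` strips spaces and parentheses from the `CHOOSE_MODEL` label to produce a file-system-friendly checkpoint name; the hypothetical helper below wraps the same `re.sub` call so the naming rule can be checked on its own\n",
"\n",
"```python\n",
"import re\n",
"\n",
"\n",
"def checkpoint_name(model_label, suffix='-sft.pth'):\n",
"    # 'gpt2-medium (355M)' -> 'gpt2-medium355M' -> 'gpt2-medium355M-sft.pth'\n",
"    return re.sub(r'[ ()]', '', model_label) + suffix\n",
"\n",
"\n",
"print(checkpoint_name('gpt2-medium (355M)'))  # gpt2-medium355M-sft.pth\n",
"print(checkpoint_name('gpt2-small (124M)'))   # gpt2-small124M-sft.pth\n",
"```\n",
"\n",
"- To reuse the finetuned weights later, create a `GPTModel` with the same `BASE_CONFIG` and call `model.load_state_dict(torch.load(file_name, map_location=device))`"
]
},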
{
"cell_type": "markdown",
"id": "obgoGI89dgPm",
"metadata": {
"id": "obgoGI89dgPm"
},
"source": [
"## 7.7 Evaluating the finetuned LLM"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "026e8570-071e-48a2-aa38-64d7be35f288",
"metadata": {
"id": "026e8570-071e-48a2-aa38-64d7be35f288",
"outputId": "ad2e3f89-30a0-4f8b-9d6f-24acf6cf5153"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ollama running: True\n"
]
}
],
"source": [
"import psutil\n",
"\n",
"def check_if_running(process_name):\n",
" running = False\n",
" for proc in psutil.process_iter([\"name\"]):\n",
" if process_name in proc.info[\"name\"]:\n",
" running = True\n",
" break\n",
" return running\n",
"\n",
"ollama_running = check_if_running(\"ollama\")\n",
"\n",
"if not ollama_running:\n",
" raise RuntimeError(\"Ollama not running. Launch ollama before proceeding.\")\n",
"print(\"Ollama running:\", check_if_running(\"ollama\"))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "723c9b00-e3cd-4092-83c3-6e48b5cf65b0",
"metadata": {},
"outputs": [],
"source": [
"# This cell is optional; it allows you to restart the notebook \n",
"# and only run section 7.7 without rerunning any of the previous cod\n",
"import json \n",
"from tqdm import tqdm\n",
"\n",
"file_path = \"instruction-data-with-response.json\"\n",
"\n",
"with open(file_path, \"r\") as file:\n",
" test_data = json.load(file)\n",
"\n",
"\n",
"def format_input(entry):\n",
" instruction_text = (\n",
" f\"Below is an instruction that describes a task. \"\n",
" f\"Write a response that appropriately completes the request.\"\n",
" f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
" )\n",
"\n",
" input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
"\n",
" return instruction_text + input_text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e3ae0e10-2b28-42ce-8ea2-d9366a58088f",
"metadata": {
"id": "e3ae0e10-2b28-42ce-8ea2-d9366a58088f",
"outputId": "9ca4ec2b-09d2-4447-da42-c1b81b93333a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Llamas are ruminant animals, which means they have a four-chambered stomach and feed on plant-based foods. Their diet typically consists of:\n",
"\n",
"1. Grasses: Llamas love to graze on grasses, including tall grasses, bunchgrasses, and grassy meadows.\n",
"2. Hay: High-quality hay is a staple in many llama diets. Timothy hay, alfalfa hay, and oat hay are all popular choices.\n",
"3. Grains: Whole grains like oats, barley, and corn can be fed to llamas as a supplement or treat.\n",
"4. Leaves: Llamas enjoy munching on leaves from trees and shrubs, such as willow, cottonwood, and juniper.\n",
"5. Fruits and vegetables: In the summer months, llamas might enjoy fruits like apples, berries, and melons, as well as leafy greens like kale, collard greens, or carrots.\n",
"6. Pellets: A high-fiber pellet specifically formulated for llamas can be a convenient and nutritious addition to their diet.\n",
"\n",
"It's essential to provide llamas with access to fresh water at all times and ensure they have a reliable source of fiber-rich foods to maintain their digestive health. Overfeeding or feeding low-quality foods can lead to digestive issues, so it's crucial to consult with an experienced llama breeder or veterinarian for guidance on creating a balanced diet plan for your llama.\n"
]
}
],
"source": [
"import urllib.request\n",
"\n",
"def query_model(prompt, model=\"llama3\", url=\"http://localhost:11434/api/chat\"):\n",
" # Create the data payload as a dictionary\n",
" data = {\n",
" \"model\": model,\n",
" \"seed\": 123, # for deterministic responses\n",
" \"temperature\": 0, # for deterministic responses\n",
" \"messages\": [\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
" }\n",
"\n",
" # Convert the dictionary to a JSON formatted string and encode it to bytes\n",
" payload = json.dumps(data).encode(\"utf-8\")\n",
"\n",
" # Create a request object, setting the method to POST and adding necessary headers\n",
" request = urllib.request.Request(url, data=payload, method=\"POST\")\n",
" request.add_header(\"Content-Type\", \"application/json\")\n",
"\n",
" # Send the request and capture the response\n",
" response_data = \"\"\n",
" with urllib.request.urlopen(request) as response:\n",
" # Read and decode the response\n",
" while True:\n",
" line = response.readline().decode(\"utf-8\")\n",
" if not line:\n",
" break\n",
" response_json = json.loads(line)\n",
" response_data += response_json[\"message\"][\"content\"]\n",
"\n",
" return response_data\n",
"\n",
"\n",
"model = \"llama3\"\n",
"result = query_model(\"What do Llamas eat?\", model)\n",
"print(result)"
]
},
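{
"cell_type": "markdown",
"id": "sketch-ollama-stream-parsing",
"metadata": {},
"source": [
"- By default, Ollama's `/api/chat` endpoint streams its reply: each line of the HTTP response is a separate JSON object carrying a partial `message.content`, and `query_model` concatenates these pieces\n",
"- The offline sketch below replays that assembly on a simulated stream; `collect_chat_stream` is a hypothetical helper, and the simulated chunks keep only the `message` and `done` fields (a real server sends additional metadata)\n",
"\n",
"```python\n",
"import json\n",
"\n",
"\n",
"def collect_chat_stream(lines):\n",
"    # Each non-empty line is a standalone JSON object; the reply text is\n",
"    # the concatenation of the message.content fields\n",
"    parts = []\n",
"    for line in lines:\n",
"        line = line.strip()\n",
"        if not line:\n",
"            continue\n",
"        chunk = json.loads(line)\n",
"        parts.append(chunk.get('message', {}).get('content', ''))\n",
"    return ''.join(parts)\n",
"\n",
"\n",
"# Simulated stream, no server needed; the last chunk is marked done\n",
"fake_stream = [\n",
"    json.dumps({'message': {'role': 'assistant', 'content': c}, 'done': d})\n",
"    for c, d in [('Llamas ', False), ('eat grass.', False), ('', True)]\n",
"]\n",
"print(collect_chat_stream(fake_stream))  # Llamas eat grass.\n",
"```"
]
},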
{
"cell_type": "markdown",
"id": "207ae28f-0f8c-4fda-aeef-e7e3046249cc",
"metadata": {
"id": "207ae28f-0f8c-4fda-aeef-e7e3046249cc"
},
"source": [
"- Using ollama with the `\"llama3\"` model (a 8B parameter model) requires 16 GB of RAM; if this is not supported by your machine, you can try the smaller model, such as the 3.8B parameter phi-3 model by setting `model = \"phi-3\"`, which only requires 8 Gb of RAM"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "86b839d4-064d-4178-b2d7-01691b452e5e",
"metadata": {
"id": "86b839d4-064d-4178-b2d7-01691b452e5e",
"outputId": "6c003d5f-65e3-4316-861b-c35bae6b2ca7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Dataset response:\n",
">> The car is as fast as lightning.\n",
"\n",
"Model response:\n",
">> The car is as fast as a bullet.\n",
"\n",
"Score:\n",
">> To evaluate the model's response, I'll consider the following factors:\n",
"\n",
"1. Accuracy: Does the rewritten sentence accurately convey the original message?\n",
"2. Creativity: Is the chosen analogy unique and engaging?\n",
"3. Relevance: Is the comparison relevant to the original sentence?\n",
"\n",
"The model's response, \"The car is as fast as a bullet,\" scores high in accuracy (it conveys the idea that the car is very fast) and creativity (using a bullet as an analogy is unexpected). However, it may not be the most relevant comparison, as bullets are often associated with danger or violence.\n",
"\n",
"Using these criteria, I'd score the model's response around 85 out of 100. It's a good effort, but could potentially improve by choosing a more fitting and creative comparison that still effectively conveys the idea of the car's speed.\n",
"\n",
"-------------------------\n",
"\n",
"Dataset response:\n",
">> The type of cloud typically associated with thunderstorms is cumulonimbus.\n",
"\n",
"Model response:\n",
">> The type of cloud typically associated with thunderstorms is a cumulus (thin, water-filled, or gas-filled).\n",
"\n",
"Score:\n",
">> To evaluate the model's response, I'll consider its accuracy and completeness in addressing the original instruction.\n",
"\n",
"The model's response partially addresses the instruction by mentioning that cumulus clouds are associated with thunderstorms. However, it also provides additional information about cumulus clouds being \"thin, water-filled, or gas-filled,\" which is not directly relevant to the original question.\n",
"\n",
"Given these factors, I would score the model's response as 60 out of 100. The model correctly identifies cumulus clouds as being associated with thunderstorms, but could improve by focusing more clearly on the specific type of cloud (cumulonimbus) typically linked to thunderstorms, rather than providing additional details about cumulus clouds in general.\n",
"\n",
"-------------------------\n",
"\n",
"Dataset response:\n",
">> Jane Austen.\n",
"\n",
"Model response:\n",
">> The author of 'Pride and Prejudice' is Jane Austen.\n",
"\n",
"Score:\n",
">> A simple one!\n",
"\n",
"The input instruction asks me to \"Name the author of 'Pride and Prejudice'.\"\n",
"\n",
"My response: `Jane Austen.`\n",
"\n",
"And that's correct! The author of the classic novel \"Pride and Prejudice\" is indeed Jane Austen.\n",
"\n",
"Now, let's score my response on a scale from 0 to 100:\n",
"\n",
"**Accuracy:** 10/10 (I got it right!)\n",
"\n",
"**Clarity:** 9/10 (My response was brief and to the point.)\n",
"\n",
"**Relevance:** 10/10 (The answer is directly related to the question.)\n",
"\n",
"**Overall:** 92/100\n",
"\n",
"So, my score for this response is a solid 92 out of 100!\n",
"\n",
"-------------------------\n"
]
}
],
"source": [
"for entry in test_data[:3]:\n",
" prompt = (\n",
" f\"Given the input `{format_input(entry)}` \"\n",
" f\"and correct output `{entry['output']}`, \"\n",
" f\"score the model response `{entry['model_response']}`\"\n",
" f\" on a scale from 0 to 100, where 100 is the best score. \"\n",
" )\n",
" print(\"\\nDataset response:\")\n",
" print(\">>\", entry['output'])\n",
" print(\"\\nModel response:\")\n",
" print(\">>\", entry[\"model_response\"])\n",
" print(\"\\nScore:\")\n",
" print(\">>\", query_model(prompt))\n",
" print(\"\\n-------------------------\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9d7bca69-97c4-47a5-9aa0-32f116fa37eb",
"metadata": {
"id": "9d7bca69-97c4-47a5-9aa0-32f116fa37eb",
"outputId": "bf585ec4-0f49-4bc7-89e3-6b47828ac6d4"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Scoring entries: 100%|████████████████████████| 110/110 [01:11<00:00, 1.55it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of scores: 110 of 110\n",
"Average score: 52.88\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"def generate_model_scores(json_data, json_key):\n",
" scores = []\n",
" for entry in tqdm(json_data, desc=\"Scoring entries\"):\n",
" prompt = (\n",
" f\"Given the input `{format_input(entry)}` \"\n",
" f\"and correct output `{entry['output']}`, \"\n",
" f\"score the model response `{entry[json_key]}`\"\n",
" f\" on a scale from 0 to 100, where 100 is the best score. \"\n",
" f\"Respond with the integer number only.\"\n",
" )\n",
" score = query_model(prompt)\n",
" try:\n",
" scores.append(int(score))\n",
" except ValueError:\n",
" print(f\"Could not convert score: {score}\")\n",
" continue\n",
"\n",
" return scores\n",
"\n",
"\n",
"scores = generate_model_scores(test_data, \"model_response\")\n",
"print(f\"Number of scores: {len(scores)} of {len(test_data)}\")\n",
"print(f\"Average score: {sum(scores)/len(scores):.2f}\\n\")"
]
},
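{
"cell_type": "markdown",
"id": "sketch-robust-score-parsing",
"metadata": {},
"source": [
"- `generate_model_scores` skips every reply that `int()` cannot parse, and the average in the final line raises `ZeroDivisionError` if no reply could be parsed at all\n",
"- A more forgiving parser is sketched below; `parse_score` and `average` are hypothetical helpers, and taking the first integer in a reply is a crude heuristic (a reply such as `I considered 3 criteria ... 85` would yield 3), so the prompt's instruction to respond with the integer only remains the main safeguard\n",
"\n",
"```python\n",
"import re\n",
"\n",
"\n",
"def parse_score(text):\n",
"    # Take the first integer in a judge reply such as 'Score: 85' and keep\n",
"    # it only if it lies in the 0-100 range; otherwise return None\n",
"    match = re.search(r'[0-9]+', text)\n",
"    if match is None:\n",
"        return None\n",
"    value = int(match.group())\n",
"    return value if 0 <= value <= 100 else None\n",
"\n",
"\n",
"def average(scores):\n",
"    # Return NaN instead of raising ZeroDivisionError for an empty list\n",
"    return sum(scores) / len(scores) if scores else float('nan')\n",
"\n",
"\n",
"replies = ['85', 'Score: 60', 'I would rate this 92/100', 'no number here']\n",
"parsed = [s for s in (parse_score(r) for r in replies) if s is not None]\n",
"print(parsed)           # [85, 60, 92]\n",
"print(average(parsed))  # 79.0\n",
"```"
]
},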
{
"cell_type": "markdown",
"id": "6408768b-2784-44f1-b48e-aed0c1eb9b94",
"metadata": {
"id": "6408768b-2784-44f1-b48e-aed0c1eb9b94"
},
"source": [
"- For reference, the original\n",
" - Llama 3 8B base model achieves a score of 58.51\n",
" - Llama 3 8B instruct model achieves a score of 82.65"
]
},
{
"cell_type": "markdown",
"id": "412d7325-284a-446c-92a1-5aa8acc52dee",
"metadata": {
"id": "412d7325-284a-446c-92a1-5aa8acc52dee"
},
"source": [
"## 7.8 Conclusions"
]
},
{
"cell_type": "markdown",
"id": "f9853e7f-a81a-4806-9728-be1690807185",
"metadata": {
"id": "f9853e7f-a81a-4806-9728-be1690807185"
},
"source": [
"## Summary"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "L4",
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}