Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-07-03 07:04:25 +00:00
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "12e91914-5f51-43fa-b65b-625e73b4d17b",
   "metadata": {},
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2520ec3-722f-4f44-bdd1-885b13e7afbf",
   "metadata": {},
   "source": [
    "# Chapter 7: Finetuning To Follow Instructions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a984b9ef-af93-415a-9ec7-97385f28af7b",
   "metadata": {},
   "source": [
    "- Comments & notes in progress ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4e19327b-6c02-4881-ad02-9b6d3ec0b1b4",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4e19327b-6c02-4881-ad02-9b6d3ec0b1b4",
    "outputId": "538e79af-011b-4a60-f288-2d0312a2b5a6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matplotlib version: 3.8.4\n",
      "tiktoken version: 0.6.0\n",
      "torch version: 2.2.2\n",
      "tqdm version: 4.66.2\n",
      "tensorflow version: 2.16.1\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\n",
    "    \"matplotlib\",  # Plotting library\n",
    "    \"tiktoken\",    # Tokenizer\n",
    "    \"torch\",       # Deep learning library\n",
    "    \"tqdm\",        # Progress bar\n",
    "    \"tensorflow\",  # For OpenAI's pretrained weights\n",
    "]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bbc68e9-75b3-41f1-ac2c-e071c3cd0813",
   "metadata": {},
   "source": [
    "## 7.1 Introduction to instruction finetuning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5384f0cf-ef3c-4436-a5fa-59bd25649f86",
   "metadata": {},
   "source": [
    "## 7.2 Preparing a dataset for supervised instruction finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0G3axLw6kY1N",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0G3axLw6kY1N",
    "outputId": "2a9a1c83-9c46-49a5-f9df-fce3320f7db2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1100\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "import urllib.request\n",
    "\n",
    "\n",
    "def download_and_load_file(file_path, url):\n",
    "\n",
    "    if not os.path.exists(file_path):\n",
    "        with urllib.request.urlopen(url) as response:\n",
    "            text_data = response.read().decode(\"utf-8\")\n",
    "        with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
    "            file.write(text_data)\n",
    "    else:\n",
    "        with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "            text_data = file.read()\n",
    "\n",
    "    with open(file_path, \"r\") as file:\n",
    "        data = json.load(file)\n",
    "\n",
    "    return data\n",
    "\n",
    "\n",
    "file_path = \"instruction-data.json\"\n",
    "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json\"\n",
    "\n",
    "data = download_and_load_file(file_path, url)\n",
    "print(len(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "-LiuBMsHkzQV",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-LiuBMsHkzQV",
    "outputId": "fc3b22fd-9a53-405e-9c25-2a5873d343d1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'instruction': 'Evaluate the following phrase by transforming it into the spelling given.', 'input': 'freind --> friend', 'output': 'The spelling of the given phrase \"freind\" is incorrect, the correct spelling is \"friend\".'}\n"
     ]
    }
   ],
   "source": [
    "print(data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "uFInFxDDk2Je",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "uFInFxDDk2Je",
    "outputId": "84cb1aad-233a-488a-f6b0-6cb977834367"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'instruction': \"Change the sentence 'You should have called me.' into a question.\", 'input': '', 'output': 'Should you have called me?'}\n"
     ]
    }
   ],
   "source": [
    "print(data[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "aFZVopbIlNfx",
   "metadata": {
    "id": "aFZVopbIlNfx"
   },
   "outputs": [],
   "source": [
    "train_portion = int(len(data) * 0.85)  # 85% for training\n",
    "test_portion = int(len(data) * 0.1)    # 10% for testing\n",
    "val_portion = len(data) - train_portion - test_portion  # Remaining 5% for validation\n",
    "\n",
    "train_data = data[:train_portion]\n",
    "test_data = data[train_portion:train_portion + test_portion]\n",
    "val_data = data[train_portion + test_portion:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "-zf6oht6bIUQ",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-zf6oht6bIUQ",
    "outputId": "bf33cd9a-2778-4365-c51d-d394c817c4fb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "935\n",
      "55\n",
      "110\n"
     ]
    }
   ],
   "source": [
    "print(len(train_data))\n",
    "print(len(val_data))\n",
    "print(len(test_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "Jhk37nnJnkBh",
   "metadata": {
    "id": "Jhk37nnJnkBh"
   },
   "outputs": [],
   "source": [
    "def format_input(entry):\n",
    "    instruction_text = (\n",
    "        f\"Below is an instruction that describes a task. \"\n",
    "        f\"Write a response that appropriately completes the request.\"\n",
    "        f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
    "    )\n",
    "\n",
    "    input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
    "\n",
    "    return instruction_text + input_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "F9UQRfjzo4Js",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "F9UQRfjzo4Js",
    "outputId": "b56e6c03-f603-4e9d-c1b6-b4a70403caf9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "Evaluate the following phrase by transforming it into the spelling given.\n",
      "\n",
      "### Input:\n",
      "freind --> friend\n"
     ]
    }
   ],
   "source": [
    "print(format_input(train_data[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcaaf606-f913-4445-8301-632ae10d387d",
   "metadata": {},
   "source": [
    "## 7.3 Creating data loaders for an instruction dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "K6MWf0lhu8GP",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "K6MWf0lhu8GP",
    "outputId": "bb01c511-4023-4b74-9781-8385da75b391"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[50256]\n"
     ]
    }
   ],
   "source": [
    "import tiktoken\n",
    "\n",
    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "print(tokenizer.encode(\"<|endoftext|>\", allowed_special={\"<|endoftext|>\"}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "adc29dc4-f1c7-4c71-937b-95119d6239bb",
   "metadata": {
    "id": "adc29dc4-f1c7-4c71-937b-95119d6239bb"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class InstructionDataset(Dataset):\n",
    "    def __init__(self, data, tokenizer):\n",
    "        self.data = data\n",
    "\n",
    "        # Pre-tokenize texts\n",
    "        self.encoded_texts = []\n",
    "        for entry in data:\n",
    "            instruction_plus_input = format_input(entry)\n",
    "            response_text = f\"\\n\\n### Response:\\n{entry['output']}\"\n",
    "            full_text = instruction_plus_input + response_text\n",
    "            self.encoded_texts.append(\n",
    "                tokenizer.encode(\n",
    "                    full_text, allowed_special={\"<|endoftext|>\"}\n",
    "                )\n",
    "            )\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.encoded_texts[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "W2jvh-OP9MFV",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "W2jvh-OP9MFV",
    "outputId": "7878ef5f-635a-491a-99b2-07b3319daefc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.1269)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Explain index masking\n",
    "\n",
    "targets = torch.tensor([0, 1])\n",
    "inputs = torch.tensor(\n",
    "    [[-1., 1.],\n",
    "     [-0.5, 1.5]]\n",
    ")\n",
    "\n",
    "torch.nn.functional.cross_entropy(inputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "nvVMuil89v9N",
   "metadata": {
    "id": "nvVMuil89v9N"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.7936)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targets = torch.tensor([0, 1, 1])\n",
    "inputs = torch.tensor(\n",
    "    [[-1., 1.],\n",
    "     [-0.5, 1.5],\n",
    "     [-0.5, 1.5]]\n",
    ")\n",
    "torch.nn.functional.cross_entropy(inputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "RTyB1vah9p56",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "RTyB1vah9p56",
    "outputId": "f1c132ad-85db-411d-cfc8-1d9ab3aec79d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.1269)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targets = torch.tensor([0, 1, -100])\n",
    "inputs = torch.tensor(\n",
    "    [[-1., 1.],\n",
    "     [-0.5, 1.5],\n",
    "     [-0.5, 1.5]]\n",
    ")\n",
    "torch.nn.functional.cross_entropy(inputs, targets)"
   ]
  },
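  {
   "cell_type": "markdown",
   "id": "a1b2c3d4-note-ignore-index",
   "metadata": {},
   "source": [
    "- Note: the loss above matches the two-example case because `torch.nn.functional.cross_entropy` uses `ignore_index=-100` by default, so the third row is excluded from the loss average; the collate function below relies on exactly this behavior to mask out padding tokens in the targets"
   ]
  },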
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "41ec6e2d-9eb2-4124-913e-d2af39be4cf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_collate_fn(\n",
    "    batch,\n",
    "    pad_token_id=50256,\n",
    "    ignore_index=-100,\n",
    "    allowed_max_length=None,\n",
    "    device=\"cpu\"\n",
    "):\n",
    "    # Find the longest sequence in the batch\n",
    "    batch_max_length = max(len(item) for item in batch)\n",
    "\n",
    "    # Pad and prepare inputs and targets\n",
    "    inputs_lst, targets_lst = [], []\n",
    "    for item in batch:\n",
    "        # Pad sequences to the batch maximum length\n",
    "        padded = item + [pad_token_id] * (batch_max_length - len(item))\n",
    "        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs\n",
    "        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets\n",
    "\n",
    "        # Replace all but the first padding token in targets by ignore_index\n",
    "        mask = targets == pad_token_id\n",
    "        indices = torch.nonzero(mask).squeeze()\n",
    "        if indices.numel() > 1:\n",
    "            targets[indices[1:]] = ignore_index\n",
    "\n",
    "        # Optionally truncate to maximum sequence length\n",
    "        if allowed_max_length is not None:\n",
    "            inputs = inputs[:allowed_max_length]\n",
    "            targets = targets[:allowed_max_length]\n",
    "\n",
    "        inputs_lst.append(inputs)\n",
    "        targets_lst.append(targets)\n",
    "\n",
    "    inputs_tensor = torch.stack(inputs_lst).to(device)\n",
    "    targets_tensor = torch.stack(targets_lst).to(device)\n",
    "\n",
    "    return inputs_tensor, targets_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cdf5eec4-9ebe-4be0-9fca-9a47bee88fdc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[    0,     1,     2,     3,     4, 50256],\n",
       "         [    7,     8,     9, 50256, 50256, 50256]]),\n",
       " tensor([[    1,     2,     3,     4, 50256,  -100],\n",
       "         [    8,     9, 50256,  -100,  -100,  -100]]))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs_1 = [0, 1, 2, 3, 4, 50256, 50256]\n",
    "inputs_2 = [7, 8, 9]\n",
    "\n",
    "batch = (\n",
    "    inputs_1,\n",
    "    inputs_2\n",
    ")\n",
    "\n",
    "custom_collate_fn(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "etpqqWh8phKc",
   "metadata": {
    "id": "etpqqWh8phKc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cpu\n"
     ]
    }
   ],
   "source": [
    "from functools import partial\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Device:\", device)\n",
    "\n",
    "customized_collate_fn = partial(custom_collate_fn, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "BtWkgir6Hlpe",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BtWkgir6Hlpe",
    "outputId": "8e3a969d-e1f6-4574-cc07-3f8401068555"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "num_workers = 0\n",
    "batch_size = 8\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "train_dataset = InstructionDataset(train_data, tokenizer)\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=customized_collate_fn,\n",
    "    shuffle=True,\n",
    "    drop_last=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1d097dc8-ad34-4f05-b435-e4147965f532",
   "metadata": {
    "id": "1d097dc8-ad34-4f05-b435-e4147965f532"
   },
   "outputs": [],
   "source": [
    "val_dataset = InstructionDataset(val_data, tokenizer)\n",
    "val_loader = DataLoader(\n",
    "    val_dataset,\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=customized_collate_fn,\n",
    "    shuffle=False,\n",
    "    drop_last=False\n",
    ")\n",
    "\n",
    "test_dataset = InstructionDataset(test_data, tokenizer)\n",
    "test_loader = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=customized_collate_fn,\n",
    "    shuffle=False,\n",
    "    drop_last=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "GGs1AI3vHpnX",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GGs1AI3vHpnX",
    "outputId": "eaabe39c-bb78-4fec-979c-6382c192a79f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loader:\n",
      "torch.Size([8, 60]) torch.Size([8, 60])\n",
      "torch.Size([8, 75]) torch.Size([8, 75])\n",
      "torch.Size([8, 72]) torch.Size([8, 72])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 64]) torch.Size([8, 64])\n",
      "torch.Size([8, 71]) torch.Size([8, 71])\n",
      "torch.Size([8, 79]) torch.Size([8, 79])\n",
      "torch.Size([8, 66]) torch.Size([8, 66])\n",
      "torch.Size([8, 61]) torch.Size([8, 61])\n",
      "torch.Size([8, 74]) torch.Size([8, 74])\n",
      "torch.Size([8, 61]) torch.Size([8, 61])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 66]) torch.Size([8, 66])\n",
      "torch.Size([8, 76]) torch.Size([8, 76])\n",
      "torch.Size([8, 68]) torch.Size([8, 68])\n",
      "torch.Size([8, 78]) torch.Size([8, 78])\n",
      "torch.Size([8, 70]) torch.Size([8, 70])\n",
      "torch.Size([8, 65]) torch.Size([8, 65])\n",
      "torch.Size([8, 82]) torch.Size([8, 82])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 79]) torch.Size([8, 79])\n",
      "torch.Size([8, 70]) torch.Size([8, 70])\n",
      "torch.Size([8, 68]) torch.Size([8, 68])\n",
      "torch.Size([8, 64]) torch.Size([8, 64])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 59]) torch.Size([8, 59])\n",
      "torch.Size([8, 58]) torch.Size([8, 58])\n",
      "torch.Size([8, 68]) torch.Size([8, 68])\n",
      "torch.Size([8, 62]) torch.Size([8, 62])\n",
      "torch.Size([8, 64]) torch.Size([8, 64])\n",
      "torch.Size([8, 75]) torch.Size([8, 75])\n",
      "torch.Size([8, 65]) torch.Size([8, 65])\n",
      "torch.Size([8, 70]) torch.Size([8, 70])\n",
      "torch.Size([8, 90]) torch.Size([8, 90])\n",
      "torch.Size([8, 64]) torch.Size([8, 64])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 66]) torch.Size([8, 66])\n",
      "torch.Size([8, 65]) torch.Size([8, 65])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 64]) torch.Size([8, 64])\n",
      "torch.Size([8, 74]) torch.Size([8, 74])\n",
      "torch.Size([8, 88]) torch.Size([8, 88])\n",
      "torch.Size([8, 58]) torch.Size([8, 58])\n",
      "torch.Size([8, 87]) torch.Size([8, 87])\n",
      "torch.Size([8, 82]) torch.Size([8, 82])\n",
      "torch.Size([8, 82]) torch.Size([8, 82])\n",
      "torch.Size([8, 69]) torch.Size([8, 69])\n",
      "torch.Size([8, 64]) torch.Size([8, 64])\n",
      "torch.Size([8, 73]) torch.Size([8, 73])\n",
      "torch.Size([8, 75]) torch.Size([8, 75])\n",
      "torch.Size([8, 66]) torch.Size([8, 66])\n",
      "torch.Size([8, 74]) torch.Size([8, 74])\n",
      "torch.Size([8, 82]) torch.Size([8, 82])\n",
      "torch.Size([8, 68]) torch.Size([8, 68])\n",
      "torch.Size([8, 66]) torch.Size([8, 66])\n",
      "torch.Size([8, 59]) torch.Size([8, 59])\n",
      "torch.Size([8, 59]) torch.Size([8, 59])\n",
      "torch.Size([8, 65]) torch.Size([8, 65])\n",
      "torch.Size([8, 79]) torch.Size([8, 79])\n",
      "torch.Size([8, 70]) torch.Size([8, 70])\n",
      "torch.Size([8, 60]) torch.Size([8, 60])\n",
      "torch.Size([8, 57]) torch.Size([8, 57])\n",
      "torch.Size([8, 70]) torch.Size([8, 70])\n",
      "torch.Size([8, 66]) torch.Size([8, 66])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 62]) torch.Size([8, 62])\n",
      "torch.Size([8, 86]) torch.Size([8, 86])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 70]) torch.Size([8, 70])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 70]) torch.Size([8, 70])\n",
      "torch.Size([8, 60]) torch.Size([8, 60])\n",
      "torch.Size([8, 64]) torch.Size([8, 64])\n",
      "torch.Size([8, 66]) torch.Size([8, 66])\n",
      "torch.Size([8, 64]) torch.Size([8, 64])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 59]) torch.Size([8, 59])\n",
      "torch.Size([8, 71]) torch.Size([8, 71])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 69]) torch.Size([8, 69])\n",
      "torch.Size([8, 56]) torch.Size([8, 56])\n",
      "torch.Size([8, 71]) torch.Size([8, 71])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 61]) torch.Size([8, 61])\n",
      "torch.Size([8, 73]) torch.Size([8, 73])\n",
      "torch.Size([8, 79]) torch.Size([8, 79])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 69]) torch.Size([8, 69])\n",
      "torch.Size([8, 90]) torch.Size([8, 90])\n",
      "torch.Size([8, 60]) torch.Size([8, 60])\n",
      "torch.Size([8, 65]) torch.Size([8, 65])\n",
      "torch.Size([8, 79]) torch.Size([8, 79])\n",
      "torch.Size([8, 80]) torch.Size([8, 80])\n",
      "torch.Size([8, 73]) torch.Size([8, 73])\n",
      "torch.Size([8, 81]) torch.Size([8, 81])\n",
      "torch.Size([8, 62]) torch.Size([8, 62])\n",
      "torch.Size([8, 82]) torch.Size([8, 82])\n",
      "torch.Size([8, 67]) torch.Size([8, 67])\n",
      "torch.Size([8, 66]) torch.Size([8, 66])\n",
      "torch.Size([8, 76]) torch.Size([8, 76])\n",
      "torch.Size([8, 90]) torch.Size([8, 90])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 60]) torch.Size([8, 60])\n",
      "torch.Size([8, 74]) torch.Size([8, 74])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 65]) torch.Size([8, 65])\n",
      "torch.Size([8, 77]) torch.Size([8, 77])\n",
      "torch.Size([8, 65]) torch.Size([8, 65])\n",
      "torch.Size([8, 63]) torch.Size([8, 63])\n",
      "torch.Size([8, 82]) torch.Size([8, 82])\n",
      "torch.Size([8, 65]) torch.Size([8, 65])\n",
      "torch.Size([8, 73]) torch.Size([8, 73])\n",
      "torch.Size([8, 68]) torch.Size([8, 68])\n"
     ]
    }
   ],
   "source": [
    "print(\"Train loader:\")\n",
    "for x, y in train_loader:\n",
    "    print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "21b8fd02-014f-4481-9b71-5bfee8f9dfcd",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "21b8fd02-014f-4481-9b71-5bfee8f9dfcd",
    "outputId": "71ce098a-36b7-44fa-8c7c-f63db448fe40"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 30003, 6525, 262, 6827, 1262, 257, 985, 576, 13, 198, 198, 21017, 23412, 25, 198, 464, 5156, 318, 845, 13779, 13, 198, 198, 21017, 18261, 25, 198, 464, 5156, 318, 355, 13779, 355, 257, 4936, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, "
     ]
    }
   ],
   "source": [
    "for i in x[0]:\n",
    "    print(i.item(), end=\", \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "51649ab4-1a7e-4a9e-92c5-950a24fde211",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "51649ab4-1a7e-4a9e-92c5-950a24fde211",
    "outputId": "4cf98eac-b7f7-4687-b264-4508c0865865"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 30003, 6525, 262, 6827, 1262, 257, 985, 576, 13, 198, 198, 21017, 23412, 25, 198, 464, 5156, 318, 845, 13779, 13, 198, 198, 21017, 18261, 25, 198, 464, 5156, 318, 355, 13779, 355, 257, 4936, 13, 50256, -100, -100, -100, -100, -100, -100, -100, -100, "
     ]
    }
   ],
   "source": [
    "for i in y[0]:\n",
    "    print(i.item(), end=\", \")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6aad445-8f19-4238-b9bf-db80767fb91a",
   "metadata": {
    "id": "d6aad445-8f19-4238-b9bf-db80767fb91a"
   },
   "source": [
    "## 7.4 Loading a pretrained LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "0d249d67-5eba-414e-9bd2-972ebf01329d",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0d249d67-5eba-414e-9bd2-972ebf01329d",
    "outputId": "0ccd8d13-4f8a-44ce-ea22-0b6ea36bb06e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File already exists and is up-to-date: gpt2/355M/checkpoint\n",
      "File already exists and is up-to-date: gpt2/355M/encoder.json\n",
      "File already exists and is up-to-date: gpt2/355M/hparams.json\n",
      "File already exists and is up-to-date: gpt2/355M/model.ckpt.data-00000-of-00001\n",
      "File already exists and is up-to-date: gpt2/355M/model.ckpt.index\n",
      "File already exists and is up-to-date: gpt2/355M/model.ckpt.meta\n",
      "File already exists and is up-to-date: gpt2/355M/vocab.bpe\n"
     ]
    }
   ],
   "source": [
    "from gpt_download import download_and_load_gpt2\n",
    "from previous_chapters import GPTModel, load_weights_into_gpt\n",
    "\n",
    "\n",
    "BASE_CONFIG = {\n",
    "    \"vocab_size\": 50257,     # Vocabulary size\n",
    "    \"context_length\": 1024,  # Context length\n",
    "    \"drop_rate\": 0.0,        # Dropout rate\n",
    "    \"qkv_bias\": True         # Query-key-value bias\n",
    "}\n",
    "\n",
    "model_configs = {\n",
    "    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
    "    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
    "    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
    "    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
    "}\n",
    "\n",
    "CHOOSE_MODEL = \"gpt2-medium (355M)\"\n",
    "\n",
    "BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
    "\n",
    "model_size = CHOOSE_MODEL.split(\" \")[-1].lstrip(\"(\").rstrip(\")\")\n",
    "settings, params = download_and_load_gpt2(model_size=model_size, models_dir=\"gpt2\")\n",
    "\n",
    "model = GPTModel(BASE_CONFIG)\n",
    "load_weights_into_gpt(model, params)\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7bd32b7c-5b44-4d25-a09f-46836802ca74",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7bd32b7c-5b44-4d25-a09f-46836802ca74",
    "outputId": "de446b9d-7667-48a5-c34a-f3c5cf70459b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "Convert the active sentence to passive: 'The chef cooks the meal every day.'\n",
      "\n",
      "### Response:\n",
      "\n",
      "The chef cooks the meal every day.\n",
      "\n",
      "### Instruction:\n",
      "\n",
      "Convert the active sentence to passive: 'The chef cooks the\n"
     ]
    }
   ],
   "source": [
    "from previous_chapters import (\n",
    "    generate,\n",
    "    text_to_token_ids,\n",
    "    token_ids_to_text\n",
    ")\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(format_input(val_data[0]), tokenizer),\n",
    "    max_new_tokens=35,\n",
    "    context_size=BASE_CONFIG[\"context_length\"],\n",
    ")\n",
    "\n",
    "print(token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d27b9d-a942-4cf5-b797-848c5f01e723",
   "metadata": {
    "id": "70d27b9d-a942-4cf5-b797-848c5f01e723"
   },
   "source": [
    "## 7.5 Finetuning the LLM on instruction data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "65444865-df87-4d98-9faf-875e1c4be860",
   "metadata": {
    "id": "65444865-df87-4d98-9faf-875e1c4be860"
   },
   "outputs": [],
   "source": [
    "from previous_chapters import (\n",
    "    calc_loss_loader,\n",
    "    train_model_simple\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d99fc6f8-63b2-43da-adbb-a7b6b92c8dd5",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "d99fc6f8-63b2-43da-adbb-a7b6b92c8dd5",
    "outputId": "0c815e75-9357-42e6-fdf3-3ea13ffa4da4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss: 3.8234103202819822\n",
      "Validation loss: 3.7612109184265137\n"
     ]
    }
   ],
   "source": [
    "model.to(device)\n",
    "\n",
    "torch.manual_seed(123)  # For reproducibility due to the shuffling in the data loader\n",
    "\n",
    "with torch.no_grad():  # Disable gradient tracking for efficiency because we are not training, yet\n",
    "    train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)\n",
    "    val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)\n",
    "\n",
    "print(\"Training loss:\", train_loss)\n",
    "print(\"Validation loss:\", val_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db4b57fb-e689-4550-931c-6d34a932487c",
   "metadata": {},
   "source": [
    "- Runtimes:\n",
    "\n",
    "<div style=\"text-align: left;\">\n",
    " \n",
    "| Model              | Platform             | Runtime       |\n",
    "|--------------------|----------------------|---------------|\n",
    "| gpt2-medium (355M) | CPU (M3 MacBook Air) | 23.67 minutes |\n",
    "| gpt2-medium (355M) | GPU (A100)           | 1.29 minutes  |\n",
    "| gpt2-small (124M)  | CPU (M3 MacBook Air) | 8.61 minutes  |\n",
    "| gpt2-small (124M)  | GPU (A100)           | 0.59 minutes  |\n",
    "\n",
    "</div>\n",
    "\n",
    "- The remainder of the notebook was run on an M3 MacBook Air with the `\"gpt2-medium (355M)\"` model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "78bcf83a-1fff-4540-97c1-765c4016d5e3",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "78bcf83a-1fff-4540-97c1-765c4016d5e3",
    "outputId": "315368d9-5484-4527-f42d-b0d650d6aa23"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ep 1 (Step 000000): Train loss 2.636, Val loss 2.627\n",
      "Ep 1 (Step 000005): Train loss 1.173, Val loss 1.103\n",
      "Ep 1 (Step 000010): Train loss 0.873, Val loss 0.947\n",
      "Ep 1 (Step 000015): Train loss 0.856, Val loss 0.907\n",
      "Ep 1 (Step 000020): Train loss 0.777, Val loss 0.882\n",
      "Ep 1 (Step 000025): Train loss 0.754, Val loss 0.860\n",
      "Ep 1 (Step 000030): Train loss 0.799, Val loss 0.838\n",
      "Ep 1 (Step 000035): Train loss 0.715, Val loss 0.810\n",
      "Ep 1 (Step 000040): Train loss 0.673, Val loss 0.807\n",
      "Ep 1 (Step 000045): Train loss 0.634, Val loss 0.791\n",
      "Ep 1 (Step 000050): Train loss 0.663, Val loss 0.784\n",
      "Ep 1 (Step 000055): Train loss 0.760, Val loss 0.764\n",
      "Ep 1 (Step 000060): Train loss 0.721, Val loss 0.745\n",
      "Ep 1 (Step 000065): Train loss 0.654, Val loss 0.736\n",
      "Ep 1 (Step 000070): Train loss 0.535, Val loss 0.730\n",
      "Ep 1 (Step 000075): Train loss 0.569, Val loss 0.729\n",
      "Ep 1 (Step 000080): Train loss 0.606, Val loss 0.726\n",
      "Ep 1 (Step 000085): Train loss 0.511, Val loss 0.710\n",
      "Ep 1 (Step 000090): Train loss 0.563, Val loss 0.691\n",
      "Ep 1 (Step 000095): Train loss 0.501, Val loss 0.682\n",
      "Ep 1 (Step 000100): Train loss 0.504, Val loss 0.678\n",
      "Ep 1 (Step 000105): Train loss 0.566, Val loss 0.671\n",
      "Ep 1 (Step 000110): Train loss 0.556, Val loss 0.668\n",
      "Ep 1 (Step 000115): Train loss 0.509, Val loss 0.665\n",
      "Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Convert the active sentence to passive: 'The chef cooks the meal every day.' ### Response: The meal is prepared every day by the chef.<|endoftext|>The following is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Convert the active sentence to passive:\n",
      "Ep 2 (Step 000120): Train loss 0.436, Val loss 0.672\n",
      "Ep 2 (Step 000125): Train loss 0.451, Val loss 0.687\n",
      "Ep 2 (Step 000130): Train loss 0.447, Val loss 0.682\n",
      "Ep 2 (Step 000135): Train loss 0.405, Val loss 0.681\n",
      "Ep 2 (Step 000140): Train loss 0.407, Val loss 0.680\n",
      "Ep 2 (Step 000145): Train loss 0.370, Val loss 0.681\n",
      "Ep 2 (Step 000150): Train loss 0.382, Val loss 0.676\n",
      "Ep 2 (Step 000155): Train loss 0.413, Val loss 0.676\n",
      "Ep 2 (Step 000160): Train loss 0.414, Val loss 0.685\n",
      "Ep 2 (Step 000165): Train loss 0.379, Val loss 0.688\n",
      "Ep 2 (Step 000170): Train loss 0.322, Val loss 0.683\n",
      "Ep 2 (Step 000175): Train loss 0.338, Val loss 0.670\n",
      "Ep 2 (Step 000180): Train loss 0.393, Val loss 0.659\n",
      "Ep 2 (Step 000185): Train loss 0.417, Val loss 0.659\n",
      "Ep 2 (Step 000190): Train loss 0.342, Val loss 0.649\n",
      "Ep 2 (Step 000195): Train loss 0.330, Val loss 0.635\n",
      "Ep 2 (Step 000200): Train loss 0.312, Val loss 0.634\n",
      "Ep 2 (Step 000205): Train loss 0.355, Val loss 0.630\n",
      "Ep 2 (Step 000210): Train loss 0.371, Val loss 0.629\n",
      "Ep 2 (Step 000215): Train loss 0.394, Val loss 0.633\n",
      "Ep 2 (Step 000220): Train loss 0.302, Val loss 0.646\n",
      "Ep 2 (Step 000225): Train loss 0.344, Val loss 0.659\n",
      "Ep 2 (Step 000230): Train loss 0.292, Val loss 0.656\n",
      "Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Convert the active sentence to passive: 'The chef cooks the meal every day.' ### Response: The meal is cooked everyday by the chef.<|endoftext|>The following is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: What is the capital of the United Kingdom\n",
      "Ep 3 (Step 000235): Train loss 0.327, Val loss 0.663\n",
      "Ep 3 (Step 000240): Train loss 0.275, Val loss 0.693\n",
      "Ep 3 (Step 000245): Train loss 0.275, Val loss 0.707\n",
      "Ep 3 (Step 000250): Train loss 0.246, Val loss 0.698\n",
      "Ep 3 (Step 000255): Train loss 0.277, Val loss 0.688\n",
      "Ep 3 (Step 000260): Train loss 0.268, Val loss 0.687\n",
      "Ep 3 (Step 000265): Train loss 0.269, Val loss 0.694\n",
      "Ep 3 (Step 000270): Train loss 0.282, Val loss 0.707\n",
      "Ep 3 (Step 000275): Train loss 0.275, Val loss 0.701\n",
      "Ep 3 (Step 000280): Train loss 0.293, Val loss 0.709\n",
      "Ep 3 (Step 000285): Train loss 0.291, Val loss 0.711\n",
      "Ep 3 (Step 000290): Train loss 0.288, Val loss 0.710\n",
      "Ep 3 (Step 000295): Train loss 0.268, Val loss 0.703\n",
      "Ep 3 (Step 000300): Train loss 0.262, Val loss 0.691\n",
      "Ep 3 (Step 000305): Train loss 0.268, Val loss 0.688\n",
      "Ep 3 (Step 000310): Train loss 0.270, Val loss 0.692\n",
      "Ep 3 (Step 000315): Train loss 0.234, Val loss 0.697\n",
      "Ep 3 (Step 000320): Train loss 0.252, Val loss 0.696\n",
      "Ep 3 (Step 000325): Train loss 0.235, Val loss 0.701\n",
      "Ep 3 (Step 000330): Train loss 0.239, Val loss 0.697\n",
      "Ep 3 (Step 000335): Train loss 0.229, Val loss 0.687\n",
      "Ep 3 (Step 000340): Train loss 0.246, Val loss 0.684\n",
      "Ep 3 (Step 000345): Train loss 0.243, Val loss 0.676\n",
      "Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Convert the active sentence to passive: 'The chef cooks the meal every day.' ### Response: The chef cooks the meal every day.<|endoftext|>The following is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: What is the capital of the United Kingdom? \n",
      "Training completed in 23.67 minutes.\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)\n",
    "\n",
    "num_epochs = 3\n",
    "\n",
    "train_losses, val_losses, tokens_seen = train_model_simple(\n",
    "    model, train_loader, val_loader, optimizer, device,\n",
    "    num_epochs=num_epochs, eval_freq=5, eval_iter=5,\n",
    "    start_context=format_input(val_data[0]), tokenizer=tokenizer\n",
    ")\n",
    "\n",
    "end_time = time.time()\n",
    "execution_time_minutes = (end_time - start_time) / 60\n",
    "print(f\"Training completed in {execution_time_minutes:.2f} minutes.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "1Vdh7jmHI1we",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 307
    },
    "id": "1Vdh7jmHI1we",
    "outputId": "97990a7a-605b-4634-9c6f-085d800eed71"
   },
   "outputs": [
    {
     "data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEiCAYAAAA21pHjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABg7klEQVR4nO3deVxU1fsH8M/MwGzs+w6yCYhsohBiqUmhmYpZmvl1yS0LM7+Wmt/K9VdamlpplplSWWlqmuUW7om4g4IgiSKgsqjIDjMwc35/XBkdWWQZmAGe9+t1XzD3nnvvc7jMPHPuOfdeHmOMgRBCCCE6ia/tAAghhBBSP0rUhBBCiA6jRE0IIYToMErUhBBCiA6jRE0IIYToMErUhBBCiA6jRE0IIYToMErUhBBCiA6jRE0IIYToMErUhHQgN27cAI/HQ2JiorZDIYRoCCVqQnQMj8drcFq4cKG2QySEtCE9bQdACFGXk5Oj+n3r1q2YP38+0tLSVPMMDQ21ERYhREuoRU2IjrG1tVVNJiYm4PF4qtfW1tZYuXIlHB0dIRKJEBgYiP3799e7LYVCgYkTJ8Lb2xtZWVkAgD/++AM9evSAWCyGm5sbFi1ahOrqatU6PB4PGzZswPDhwyGVSuHp6Yndu3erlt+/fx9jxoyBlZUVJBIJPD09sWnTpnpj2L59O/z8/CCRSGBhYYGIiAiUlZWplm/YsAE+Pj4Qi8Xw9vbG119/rbZ+dnY2Ro4cCVNTU5ibm2PYsGG4ceOGavmECRMQFRWFFStWwM7ODhYWFoiOjkZVVVWj/+aE6DRGCNFZmzZtYiYmJqrXK1euZMbGxuzXX39lV65cYXPmzGH6+vrs33//ZYwxlpGRwQCwhIQEVllZyYYPH86CgoJYfn4+Y4yx48ePM2NjYxYTE8OuXbvG/v77b9alSxe2cOFC1T4AMEdHR/bLL7+wq1evshkzZjBDQ0N27949xhhj0dHRLDAwkJ09e5ZlZGSw2NhYtnv37jrjv337NtPT02MrV65kGRkZ7NKlS2zt2rWspKSEMcbY5s2bmZ2dHduxYwe7fv0627FjBzM3N2cxMTGMMcbkcjnz8fFhEydOZJcuXWIpKSnstddeY15eXkwmkzHGGBs/fjwzNjZm06ZNY6mpqezPP/9kUqmUrV+/XrMHgxAtoURNiA57PFHb29uzjz/+WK1Mr1692FtvvcUYe5io//nnHzZgwADWp08fVlhYqCo7YMAA9sknn6it/9NPPzE7OzvVawDsww8/VL0uLS1lANi+ffsYY4wNGTKEvf76642K//z58wwAu3HjRp3L3d3d2S+//KI2b8mSJSwsLEwVm5eXF1MqlarlMpmMSSQSduDAAcYYl6hdXFxYdXW1qswrr7zCRo0a1agYCdF11EdNSDtRXFyM27dvIzw8XG1+eHg4Ll68qDZv9OjRcHR0xOHDhyGRSFTzL168iLi4OHz88ceqeQqFApWVlSgvL4dUKgUA+Pv7q5YbGBjA2NgY+fn5AIA333wTI0aMwIULF/D8888jKioKvXv3rjPmgIAADBgwAH5+foiMjMTzzz+Pl19+GWZmZigrK8O1a9cwadIkTJkyRbVOdXU1TExMVPGmp6fDyMhIbbuVlZW4du2a6rWvry8EAoHqtZ2dHZKSkhr4axLSflCiJqQDeuGFF7B582bEx8fj2WefVc0vLS3FokWL8NJLL9VaRywWq37X19dXW8bj8aBUKgEAgwYNQmZmJvbu3YvY2FgMGDAA0dHRWLFiRa1tCgQCxMbG4uTJk/j777/x1Vdf4YMPPsDp06dVXwq+++47hIaG1lqvJt7g4GD8/PPPtbZtZWXVqHgJae8oURPSThgbG8Pe3h5xcXHo27evan5cXBxCQkLUyr755pvo3r07hg4dij179qjK9+jRA2lpafDw8GhRLFZWVhg/fjzGjx+Pp59+GrNnz64zUQNc0gwPD0d4eDjmz58PFxcX7Ny5E7NmzYK9vT2uX7+OMWPG1Llujx49sHXrVlhbW8PY2LhFMRPSXlGiJqQdmT17NhYsWAB3d3cEBgZi06ZNSExMrLPF+fbbb0OhUODFF1/Evn370KdPH8yfPx8vvvginJ2d8fLLL4PP5+PixYtITk7G//3f/zUqhvnz5yM4OBi+vr6QyWT466+/4OPjU2fZ06dP49ChQ3j++edhbW2N06dP486dO6ryixYtwowZM2BiYoKBAwdCJpPh3LlzuH//PmbNmoUxY8Zg+fLlGDZsGBYvXgxHR0dkZmbi999/x5w5c+Do6Nj8PyYh7QQlakLakRkzZqCoqAjvvvsu8vPz0a1bN+zevRuenp51lp85cyaUSiVeeOEF7N+/H5GRkfjrr7+wePFifPrpp9DX14e3tzcmT57c6BiEQiHmzZuHGzduQCKR4Omnn8aWLVvqLGtsbIzjx49j9erVKC4uhouLCz7//HMMGjQIADB58mRIpVIsX74cs2fPhoGBAfz8/DBz5kwAgFQqxfHjxzF37ly89NJLKCkpgYODAwYMGEAtbNJp8BhjTNtBEEIIIaRudMMTQgghRIdRoiaEEEJ0GCVqQgghRIdRoiaEEEJ0GCVqQgghRIdRoiaEEEJ0GCXqZlq7di26dOkCsViM0NBQnDlzps32vXTpUvTq1QtGRkawtrZGVFSU2vOKAaBfv37g8Xhq07Rp09TKZGVlYfDgwZBKpbC2tsbs2bPVHncIAEePHkWPHj0gEong4eGBmJiYWvE092+xcOHCWjF6e3urlldWViI6OhoWFhYwNDTEiBEjkJeXp1N1qNGlS5dadeHxeIiOjgagu8fj+PHjGDJkCOzt7cHj8bBr1y615YwxzJ8/H3Z2dpBIJIiIiMDVq1fVyhQUFGDMmDEwNjaGqakpJk2ahNLSUrUyly5dwtNPPw2xWAwnJyd89tlntWLZtm0bvL29IRaL4efnh7179zYplobqUlVVhblz58LPzw8GBgawt7fHuHHjcPv2bbV91HUcly1b1qZ1edIxmTBhQq0YBw4cqHPH5En1qOv9wuPxsHz5cp06HjpBiw8Eabe2bNnChEIh27hxI7t8+TKbMmUKMzU1ZXl5eW2y/8jISLZp0yaWnJzMEhMT2QsvvMCcnZ1ZaWmpqkzfvn3ZlClTWE5OjmoqKipSLa+urmbdu3dnERERLCEhge3du5dZWlqyefPmqcpcv36dSaVSNmvWLJaSksK++uorJhAI2P79+zXyt1iwYAHz9fVVi/HOnTuq5dOmTWNOTk7s0KFD7Ny5c+ypp55ivXv31qk61MjPz1erR2xsLAPAjhw5otPHY+/eveyDDz5gv//+OwPAdu7cqbZ82bJlzMTEhO3atYtdvHiRDR06lLm6urKKigpVmYEDB7KAgAB26tQp9s8//zAPDw82evRo1fKioiJmY2PDxowZw5KTk9mvv/7KJBIJ+/bbb1Vl4uLimEAgYJ999hlLSUlhH374IdPX12dJSUmNjqWhuhQWFrKIiAi2detWduXKFRYfH89CQkJYcHCwWn1dXFzY4sWL1Y7To
++rtqjLk47J+PHj2cCBA9ViLCgoUCujC8fkSfV4NP6cnBy2ceNGxuPx2LVr13TqeOgCStTNEBISwqKjo1WvFQoFs7e3Z0uXLtVKPPn5+QwAO3bsmGpe37592TvvvFPvOnv37mV8Pp/l5uaq5q1bt44ZGxurnvM7Z84c5uvrq7beqFGjWGRkpOp1S/4WCxYsYAEBAXUuKywsZPr6+mzbtm2qeampqQwAi4+P15k61Oedd95h7u7uqscztofj8fiHqVKpZLa2tmz58uWqeYWFhUwkErFff/2VMcZYSkoKA8DOnj2rKrNv3z7G4/HYrVu3GGOMff3118zMzExVD8YYmzt3LvPy8lK9HjlyJBs8eLBaPKGhoeyNN95odCwN1aUuZ86cYQBYZmamap6LiwtbtWpVveu0dV3qS9TDhg2rN0ZdPCaNOR7Dhg1jzz77rNo8XTse2kKnvptILpfj/PnziIiIUM3j8/mIiIhAfHy8VmIqKioCAJibm6vN//nnn2FpaYnu3btj3rx5KC8vVy2Lj4+
|
||
|
"text/plain": [
|
||
|
"<Figure size 500x300 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from previous_chapters import plot_losses\n",
|
||
|
"\n",
|
||
|
"epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
|
||
|
"plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)"
|
||
|
]
|
||
|
},
|
||
|
  {
   "cell_type": "markdown",
   "id": "87b79a47-13f9-4d1f-87b1-3339bafaf2a3",
   "metadata": {},
   "source": [
    "## 7.6 Saving the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "F9QyvnRipwNc",
   "metadata": {
    "id": "F9QyvnRipwNc"
   },
   "outputs": [],
   "source": [
    "def extract_response(response_text):\n",
    "    return response_text[response_text.find(\"\\n### Response\") + len(\"\\n### Response:\") + 1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "VQ2NZMbfucAc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "VQ2NZMbfucAc",
    "outputId": "4a014e82-0741-4807-a77c-05b770940dd8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "Rewrite the sentence using a simile.\n",
      "\n",
      "### Input:\n",
      "The car is very fast.\n",
      "\n",
      "Correct response:\n",
      ">> The car is as fast as lightning.\n",
      "\n",
      "Model response:\n",
      ">> The car is as fast as a bullet.\n",
      "-------------------------------------\n",
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "What type of cloud is typically associated with thunderstorms?\n",
      "\n",
      "Correct response:\n",
      ">> The type of cloud typically associated with thunderstorms is cumulonimbus.\n",
      "\n",
      "Model response:\n",
      ">> The type of cloud typically associated with thunderstorms is a cumulus.\n",
      "-------------------------------------\n",
      "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
      "\n",
      "### Instruction:\n",
      "Name the author of 'Pride and Prejudice'.\n",
      "\n",
      "Correct response:\n",
      ">> Jane Austen.\n",
      "\n",
      "Model response:\n",
      ">> The author of 'Pride and Prejudice' is Jane Austen.\n",
      "-------------------------------------\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "for entry in test_data[:3]:\n",
    "\n",
    "    input_text = format_input(entry)\n",
    "\n",
    "    token_ids = generate(\n",
    "        model=model,\n",
    "        idx=text_to_token_ids(input_text, tokenizer).to(device),\n",
    "        max_new_tokens=256,\n",
    "        context_size=BASE_CONFIG[\"context_length\"],\n",
    "        eos_id=50256\n",
    "    )\n",
    "    response = token_ids_to_text(token_ids, tokenizer)\n",
    "    response_text = extract_response(response)\n",
    "\n",
    "    print(input_text)\n",
    "    print(f\"\\nCorrect response:\\n>> {entry['output']}\")\n",
    "    print(f\"\\nModel response:\\n>> {response_text.strip()}\")\n",
    "    print(\"-------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "-PNGKzY4snKP",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-PNGKzY4snKP",
    "outputId": "b065c0e6-a3b3-4e70-bbfd-17ff69ad317f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 110/110 [06:24<00:00, 3.50s/it]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "for i, entry in tqdm(enumerate(test_data), total=len(test_data)):\n",
    "\n",
    "    input_text = format_input(entry)\n",
    "\n",
    "    token_ids = generate(\n",
    "        model=model,\n",
    "        idx=text_to_token_ids(input_text, tokenizer).to(device),\n",
    "        max_new_tokens=256,\n",
    "        context_size=BASE_CONFIG[\"context_length\"],\n",
    "        eos_id=50256\n",
    "    )\n",
    "    response = token_ids_to_text(token_ids, tokenizer)\n",
    "    response_text = extract_response(response)\n",
    "\n",
    "    test_data[i][\"model_response\"] = response_text\n",
    "\n",
    "\n",
    "with open(\"instruction-data-with-response.json\", \"w\") as file:\n",
    "    json.dump(test_data, file, indent=4)  # \"indent\" for pretty-printing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "u-AvCCMTnPSE",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "u-AvCCMTnPSE",
    "outputId": "6968bb22-04e5-4473-90bc-4ed6af6aa0cf"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'instruction': 'Rewrite the sentence using a simile.',\n",
       " 'input': 'The car is very fast.',\n",
       " 'output': 'The car is as fast as lightning.',\n",
       " 'model_response': 'The car is as fast as a bullet.'}"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8cBU0iHmVfOI",
   "metadata": {
    "id": "8cBU0iHmVfOI"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved as gpt2-medium355M-sft.pth\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "file_name = f\"{re.sub(r'[ ()]', '', CHOOSE_MODEL)}-sft.pth\"\n",
    "torch.save(model.state_dict(), file_name)\n",
    "print(f\"Model saved as {file_name}\")"
   ]
  },
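  {
   "cell_type": "markdown",
   "id": "b2c3d4e5-note-reload-sketch",
   "metadata": {},
   "source": [
    "- As a minimal sketch (added here for illustration, not part of the original notebook), the saved weights can later be restored into a freshly constructed model without retraining; this assumes the same `BASE_CONFIG` and `file_name` defined in the cells above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c3d4e5-code-reload-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: reload the finetuned weights in a later session\n",
    "# (assumes BASE_CONFIG and file_name from the cells above)\n",
    "model = GPTModel(BASE_CONFIG)\n",
    "model.load_state_dict(torch.load(file_name))\n",
    "model.eval();"
   ]
  },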
  {
   "cell_type": "markdown",
   "id": "obgoGI89dgPm",
   "metadata": {
    "id": "obgoGI89dgPm"
   },
   "source": [
    "## 7.7 Evaluating the finetuned LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "026e8570-071e-48a2-aa38-64d7be35f288",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import psutil\n",
    "\n",
    "def check_if_running(process_name):\n",
    "    running = False\n",
    "    for proc in psutil.process_iter([\"name\"]):\n",
    "        if process_name in proc.info[\"name\"]:\n",
    "            running = True\n",
    "            break\n",
    "    return running\n",
    "\n",
    "ollama_running = check_if_running(\"ollama\")\n",
    "\n",
    "if not ollama_running:\n",
    "    raise RuntimeError(\"Ollama not running. Launch ollama before proceeding.\")\n",
    "print(check_if_running(\"ollama\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "e3ae0e10-2b28-42ce-8ea2-d9366a58088f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Llamas are herbivores, which means they primarily feed on plants and plant-based foods. Their diet typically consists of:\n",
      "\n",
      "1. Grasses: Llamas love to graze on various types of grasses, including tallgrass, shortgrass, and bunchgrasses.\n",
      "2. Leaves: They enjoy munching on leaves from trees and shrubs, such as oak, maple, and willow.\n",
      "3. Fruits: Llamas enjoy fruits like apples, berries, and melons.\n",
      "4. Hay: A good quality hay, such as timothy or alfalfa, is often provided as a staple in their diet.\n",
      "5. Grains: Whole grains like oats, barley, and corn can be offered as treats or as part of their regular feed.\n",
      "6. Supplements: In some cases, llama owners may choose to add commercial supplements or mineral blocks to ensure the animal is getting all the necessary nutrients.\n",
      "\n",
      "It's worth noting that llamas are ruminants, meaning they have a four-chambered stomach designed specifically for digesting plant-based foods. Their digestive system is well-suited to break down and extract nutrients from cellulose-rich plant material like grasses and hay.\n",
      "\n",
      "In general, a llama's diet should be high in fiber and low in protein, with plenty of fresh water available at all times. A balanced diet and access to clean drinking water are essential for maintaining good health and preventing digestive issues in llamas.\n"
     ]
    }
   ],
   "source": [
    "import urllib.request\n",
    "\n",
    "def query_model(prompt, model=\"llama3\", url=\"http://localhost:11434/api/chat\"):\n",
    "    # Create the data payload as a dictionary\n",
    "    data = {\n",
    "        \"model\": model,\n",
    "        # Ollama expects sampling settings nested inside \"options\";\n",
    "        # seed=123 and temperature=0 make the responses deterministic\n",
    "        \"options\": {\n",
    "            \"seed\": 123,\n",
    "            \"temperature\": 0\n",
    "        },\n",
    "        \"messages\": [\n",
    "            {\"role\": \"user\", \"content\": prompt}\n",
    "        ]\n",
    "    }\n",
    "\n",
    "    # Convert the dictionary to a JSON formatted string and encode it to bytes\n",
    "    payload = json.dumps(data).encode(\"utf-8\")\n",
    "\n",
    "    # Create a request object, setting the method to POST and adding necessary headers\n",
    "    request = urllib.request.Request(url, data=payload, method=\"POST\")\n",
    "    request.add_header(\"Content-Type\", \"application/json\")\n",
    "\n",
    "    # Send the request and capture the response\n",
    "    response_data = \"\"\n",
    "    with urllib.request.urlopen(request) as response:\n",
    "        # Read and decode the response line by line (ollama streams JSON objects)\n",
    "        while True:\n",
    "            line = response.readline().decode(\"utf-8\")\n",
    "            if not line:\n",
    "                break\n",
    "            response_json = json.loads(line)\n",
    "            response_data += response_json[\"message\"][\"content\"]\n",
    "\n",
    "    return response_data\n",
    "\n",
    "\n",
    "model = \"llama3\"\n",
    "result = query_model(\"What do Llamas eat?\", model)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "207ae28f-0f8c-4fda-aeef-e7e3046249cc",
   "metadata": {},
   "source": [
    "- Using ollama with the `\"llama3\"` model (an 8B-parameter model) requires 16 GB of RAM; if your machine doesn't support this, you can try a smaller model, such as the 3.8B-parameter phi-3 model, by setting `model = \"phi-3\"`, which requires only 8 GB of RAM (see the sketch in the next cell)"
   ]
  },
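  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d4e5f6-code-phi3-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (not run as part of this notebook): query the smaller\n",
    "# phi-3 model instead of llama3 via the query_model function defined above.\n",
    "# This assumes the model has been pulled with ollama beforehand; the exact\n",
    "# model tag may differ depending on your ollama version.\n",
    "result = query_model(\"What do Llamas eat?\", model=\"phi-3\")\n",
    "print(result)"
   ]
  },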
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "86b839d4-064d-4178-b2d7-01691b452e5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Dataset response:\n",
      ">> The car is as fast as lightning.\n",
      "\n",
      "Model response:\n",
      ">> The car is as fast as a bullet.\n",
      "\n",
      "Score:\n",
      ">> A fun task!\n",
      "\n",
      "To score this response, I'll consider the following factors:\n",
      "\n",
      "1. Grammar and syntax: The sentence is grammatically correct.\n",
      "2. Simile quality: A bullet is a relatively fast-moving object, making it a decent comparison for a fast car.\n",
      "3. Originality: While not extremely original, the comparison to a bullet is a common simile used to describe speed.\n",
      "\n",
      "Score: 85\n",
      "\n",
      "Reasoning: The response is good but not outstanding. Using a bullet as a simile for speed is a classic and understandable choice. However, it's not particularly creative or surprising, which is why I wouldn't give it a perfect score of 100. Overall, the response effectively completes the instruction and conveys the idea that the car is fast.\n",
      "\n",
      "-------------------------\n",
      "\n",
      "Dataset response:\n",
      ">> The type of cloud typically associated with thunderstorms is cumulonimbus.\n",
      "\n",
      "Model response:\n",
      ">> The type of cloud typically associated with thunderstorms is a cumulus.\n",
      "\n",
      "Score:\n",
      ">> A nice evaluation!\n",
      "\n",
      "Let's compare the model response to the correct output:\n",
      "\n",
      "Model Response: \"The type of cloud typically associated with thunderstorms is a cumulus.\"\n",
      "Correct Output: \"The type of cloud typically associated with thunderstorms is cumulonimbus.\"\n",
      "\n",
      "To score the model response, I'll consider the following factors:\n",
      "\n",
      "1. Accuracy: The model response is close but not entirely accurate. Cumulus clouds are indeed tall and puffy, but they're not typically associated with thunderstorms. Cumulonimbus clouds are the ones commonly linked to severe weather.\n",
      "Score: 60/100 (it's a good guess, but not precise)\n",
      "\n",
      "2. Relevance: The model response is somewhat relevant to the question. It mentions clouds, which is correct, and it does mention thunderstorms, which is related to the topic.\n",
      "Score: 40/100 (it's on the right track, but not entirely focused)\n",
      "\n",
      "3. Clarity: The model response is clear and easy to understand.\n",
      "Score: 80/100 (good job on that front!)\n",
      "\n",
      "Overall Score: (60 + 40 + 80) / 3 = 66.67\n",
      "\n",
      "I'd give the model response a score of **66** out of 100. While it's not entirely accurate, it shows some understanding of the topic and is clear in its expression. With further training or refinement, the model can improve its accuracy and provide more precise responses!\n",
      "\n",
      "-------------------------\n",
      "\n",
      "Dataset response:\n",
      ">> Jane Austen.\n",
      "\n",
      "Model response:\n",
      ">> The author of 'Pride and Prejudice' is Jane Austen.\n",
      "\n",
      "Score:\n",
      ">> Based on the input and expected output, I would respond as follows:\n",
      "\n",
      "### Model Response:\n",
      "The author of 'Pride and Prejudice' is Jane Austen.\n",
      "\n",
      "**Score:** 100/100\n",
      "\n",
      "Reasoning: The model response accurately completes the instruction by stating the correct author of the novel \"Pride and Prejudice\", which is indeed Jane Austen. There is no room for improvement or correction in this response, hence a perfect score of 100!\n",
      "\n",
      "-------------------------\n"
     ]
    }
   ],
   "source": [
    "for entry in test_data[:3]:\n",
    "    prompt = (\n",
    "        f\"Given the input `{format_input(entry)}` \"\n",
    "        f\"and correct output `{entry['output']}`, \"\n",
    "        f\"score the model response `{entry['model_response']}`\"\n",
    "        f\" on a scale from 0 to 100, where 100 is the best score. \"\n",
    "    )\n",
    "    print(\"\\nDataset response:\")\n",
    "    print(\">>\", entry['output'])\n",
    "    print(\"\\nModel response:\")\n",
    "    print(\">>\", entry[\"model_response\"])\n",
    "    print(\"\\nScore:\")\n",
    "    print(\">>\", query_model(prompt))\n",
    "    print(\"\\n-------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "9d7bca69-97c4-47a5-9aa0-32f116fa37eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Scoring entries: 100%|████████████████████████| 110/110 [00:46<00:00, 2.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of scores: 110 of 110\n",
      "Average score: 48.98\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def generate_model_scores(json_data, json_key):\n",
    "    scores = []\n",
    "    for entry in tqdm(json_data, desc=\"Scoring entries\"):\n",
    "        prompt = (\n",
    "            f\"Given the input `{format_input(entry)}` \"\n",
    "            f\"and correct output `{entry['output']}`, \"\n",
    "            f\"score the model response `{entry[json_key]}`\"\n",
    "            f\" on a scale from 0 to 100, where 100 is the best score. \"\n",
    "            f\"Respond with the integer number only.\"\n",
    "        )\n",
    "        score = query_model(prompt)\n",
    "        try:\n",
    "            scores.append(int(score))\n",
    "        except ValueError:\n",
    "            print(f\"Could not convert score: {score}\")\n",
    "            continue\n",
    "\n",
    "    return scores\n",
    "\n",
    "\n",
    "scores = generate_model_scores(test_data, \"model_response\")\n",
    "print(f\"Number of scores: {len(scores)} of {len(test_data)}\")\n",
    "print(f\"Average score: {sum(scores)/len(scores):.2f}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "412d7325-284a-446c-92a1-5aa8acc52dee",
   "metadata": {
    "id": "xczdTl40ajob"
   },
   "source": [
    "## 7.8 Conclusions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9853e7f-a81a-4806-9728-be1690807185",
   "metadata": {},
   "source": [
    "## Summary"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "L4",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}