mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-15 12:12:10 +00:00
Exercise solutions (#237)
This commit is contained in:
parent
06921f3333
commit
b90c7ad2d6
.gitignore (vendored, 4 changes)
@@ -1,5 +1,4 @@
# Configs and keys
-ch07/01_main-chapter-code/instruction-data-with-response-standalone.json
ch07/01_main-chapter-code/gpt2-medium355M-sft-standalone.pth
ch07/02_dataset-utilities/config.json
ch07/03_model-evaluation/config.json
@@ -68,6 +67,9 @@ ch06/03_bonus_imdb-classification/test.csv
ch06/03_bonus_imdb-classification/train.csv
ch06/03_bonus_imdb-classification/validation.csv

+ch07/01_main-chapter-code/instruction-data-with-response-standalone.json
+ch07/01_main-chapter-code/instruction-data-with-response-baseline.json
+ch07/01_main-chapter-code/instruction-data-with-response-mask-instructions.json
ch07/02_dataset-utilities/instruction-examples-modified.json

# Temporary OS-related files
@@ -56,7 +56,7 @@ Alternatively, you can view this and other files on GitHub at [https://github.co
| Ch 4: Implementing a GPT Model from Scratch | - [ch04.ipynb](ch04/01_main-chapter-code/ch04.ipynb)<br/>- [gpt.py](ch04/01_main-chapter-code/gpt.py) (summary)<br/>- [exercise-solutions.ipynb](ch04/01_main-chapter-code/exercise-solutions.ipynb) | [./ch04](./ch04) |
| Ch 5: Pretraining on Unlabeled Data | - [ch05.ipynb](ch05/01_main-chapter-code/ch05.ipynb)<br/>- [gpt_train.py](ch05/01_main-chapter-code/gpt_train.py) (summary) <br/>- [gpt_generate.py](ch05/01_main-chapter-code/gpt_generate.py) (summary) <br/>- [exercise-solutions.ipynb](ch05/01_main-chapter-code/exercise-solutions.ipynb) | [./ch05](./ch05) |
| Ch 6: Finetuning for Text Classification | - [ch06.ipynb](ch06/01_main-chapter-code/ch06.ipynb) <br/>- [gpt_class_finetune.py](ch06/01_main-chapter-code/gpt_class_finetune.py) <br/>- [exercise-solutions.ipynb](ch06/01_main-chapter-code/exercise-solutions.ipynb) | [./ch06](./ch06) |
-| Ch 7: Finetuning to Follow Instructions | - [ch07.ipynb](ch07/01_main-chapter-code/ch07.ipynb) | [./ch07](./ch07) |
+| Ch 7: Finetuning to Follow Instructions | - [ch07.ipynb](ch07/01_main-chapter-code/ch07.ipynb)<br/>- [gpt_instruction_finetuning.py](ch07/01_main-chapter-code/gpt_instruction_finetuning.py)<br/>- [ollama_evaluate.py](ch07/01_main-chapter-code/ollama_evaluate.py)<br/>- [exercise-solutions.ipynb](ch07/01_main-chapter-code/exercise-solutions.ipynb) | [./ch07](./ch07) |
| Appendix A: Introduction to PyTorch | - [code-part1.ipynb](appendix-A/01_main-chapter-code/code-part1.ipynb)<br/>- [code-part2.ipynb](appendix-A/01_main-chapter-code/code-part2.ipynb)<br/>- [DDP-script.py](appendix-A/01_main-chapter-code/DDP-script.py)<br/>- [exercise-solutions.ipynb](appendix-A/01_main-chapter-code/exercise-solutions.ipynb) | [./appendix-A](./appendix-A) |
| Appendix B: References and Further Reading | No code | - |
| Appendix C: Exercise Solutions | No code | - |
@@ -5,6 +5,8 @@
- [ch07.ipynb](ch07.ipynb) contains all the code as it appears in the chapter
- [previous_chapters.py](previous_chapters.py) is a Python module that contains the GPT model we coded and trained in previous chapters, alongside many utility functions, which we reuse in this chapter
- [gpt_download.py](gpt_download.py) contains the utility functions for downloading the pretrained GPT model weights
+- [exercise-solutions.ipynb](exercise-solutions.ipynb) contains the exercise solutions for this chapter

### Optional Code

@@ -71,3 +73,4 @@ Number of scores: 110 of 110
Average score: 51.75
```

+- [exercise_experiments.py](exercise_experiments.py) is an optional script that implements the exercise solutions; for more details, see [exercise-solutions.ipynb](exercise-solutions.ipynb)
ch07/01_main-chapter-code/exercise-solutions.ipynb (new file, 899 lines)
@@ -0,0 +1,899 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ba450fb1-8a26-4894-ab7a-5d7bfefe90ce",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<table style=\"width:100%\">\n",
|
||||
"<tr>\n",
|
||||
"<td style=\"vertical-align:middle; text-align:left;\">\n",
|
||||
"<font size=\"2\">\n",
|
||||
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
|
||||
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
|
||||
"</font>\n",
|
||||
"</td>\n",
|
||||
"<td style=\"vertical-align:middle; text-align:left;\">\n",
|
||||
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
|
||||
"</td>\n",
|
||||
"</tr>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Chapter 7 Exercise solutions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2625ddc4-9cce-42bd-947d-4e2203fdc55c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Exercise 7.1: Changing prompt styles"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6be25a95-2a33-433b-a698-2365b5fc9357",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Suppose we have the following data entry:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"instruction\": \"Identify the correct spelling of the following word.\",\n",
|
||||
" \"input\": \"Ocassion\",\n",
|
||||
" \"output\": \"The correct spelling is 'Occasion.'\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"In the main chapter, we formatted it according to the Alpaca-style prompt template:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
|
||||
"\n",
|
||||
"### Instruction:\n",
|
||||
"Identify the correct spelling of the following word.\n",
|
||||
"\n",
|
||||
"### Input:\n",
|
||||
"Occassion\n",
|
||||
"\n",
|
||||
"### Response:\n",
|
||||
"The correct spelling is 'Occasion.'\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"In this exercise, we now use the Phi-3 prompt template instead, which formats the data entry as follows:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"<user>\n",
|
||||
"Identify the correct spelling of the following word: 'Occasion'\n",
|
||||
"\n",
|
||||
"<assistant>\n",
|
||||
"The correct spelling is 'Occasion'.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Note that this prompt template is substantially shorter, which reduces the runtime and hardware requirements for finetuning the LLM and generating text since the input prompts are shorter.\n",
|
||||
"To make this change, we update the `format_input` function as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f99baa1e-c24c-417f-89d0-13e6d061ea6a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def format_input(entry):\n",
|
||||
" instruction_text = (\n",
|
||||
" f\"<|user|>\\n{entry['instruction']}\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" input_text = f\"\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
|
||||
"\n",
|
||||
" return instruction_text + input_text"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e4ba538f-64b9-495d-847b-d9f1d324bc50",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's make sure that it works as intended by applying it to two input samples, one with and one without content in the `'input'` field:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "877a57e2-535f-4363-b32a-a093edd951b8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<|user|>\n",
|
||||
"Identify the correct spelling of the following word.\n",
|
||||
"Ocassion\n",
|
||||
"\n",
|
||||
"<|user|>\n",
|
||||
"What is an antonym of 'complicated'?\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sample_data = [\n",
|
||||
" {'instruction': 'Identify the correct spelling of the following word.', 'input': 'Ocassion', 'output': \"The correct spelling is 'Occasion.'\"}, \n",
|
||||
" {'instruction': \"What is an antonym of 'complicated'?\", 'input': '', 'output': \"An antonym of 'complicated' is 'simple'.\"}\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(format_input(sample_data[0]))\n",
|
||||
"print()\n",
|
||||
"print(format_input(sample_data[1]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fa2a6704-6c61-4a09-b8f5-ffc5a77d6aa3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we also update the `InstructionDataset` class to use the <|assistant|> prompt template for the response:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "17f1a42c-7cc0-4746-8a6d-3a4cb37e2ca1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tiktoken\n",
|
||||
"from torch.utils.data import Dataset\n",
|
||||
"\n",
|
||||
"class InstructionDataset(Dataset):\n",
|
||||
" def __init__(self, data, tokenizer):\n",
|
||||
" self.data = data\n",
|
||||
"\n",
|
||||
" # Pre-tokenize texts\n",
|
||||
" self.encoded_texts = []\n",
|
||||
" for entry in data:\n",
|
||||
"\n",
|
||||
" ###################################################################\n",
|
||||
" # NEW: Use `format_input_phi` and adjust the response text template\n",
|
||||
" instruction_plus_input = format_input(entry)\n",
|
||||
" response_text = f\"\\n<|assistant|>:\\n{entry['output']}\"\n",
|
||||
" ###################################################################\n",
|
||||
" full_text = instruction_plus_input + response_text\n",
|
||||
" self.encoded_texts.append(\n",
|
||||
" tokenizer.encode(full_text)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" return self.encoded_texts[index]\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.data)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tokenizer = tiktoken.get_encoding(\"gpt2\")"
|
||||
]
|
||||
},
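{
"cell_type": "markdown",
"id": "added-phi3-template-check",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch; `format_input` is the Phi-3 variant defined above), we can assemble the full training text, including the `<|assistant|>` response template, for one sample entry:\n",
"\n",
"```python\n",
"entry = sample_data[0]\n",
"full_text = format_input(entry) + f\"\\n<|assistant|>:\\n{entry['output']}\"\n",
"print(full_text)\n",
"```"
]
},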
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b319df48-8a14-4beb-966c-6b49e255c6ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition, we have to make sure that we use the `format_input_phi` function: "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ed41ec0e-6a5a-44f1-bb3c-b7d0724cb4e1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
" train_losses, val_losses, tokens_seen = train_model_simple(\n",
|
||||
" model, train_loader, val_loader, optimizer, device,\n",
|
||||
" num_epochs=num_epochs, eval_freq=5, eval_iter=5,\n",
|
||||
" start_context=format_input_phi(val_data[0]), # New: `Use format_input_phi` function\n",
|
||||
" tokenizer=tokenizer\n",
|
||||
" )\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e0650926-c39f-4442-8116-cb7494416f28",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Lastly, we also have to update the way we extract the generated response when we collect the test set responses:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a9253041-812f-4a5f-9ab1-d7e4cb1407fb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
" for i, entry in tqdm(enumerate(test_data), total=len(test_data)):\n",
|
||||
"\n",
|
||||
" input_text = format_input(entry)\n",
|
||||
" tokenizer=tokenizer\n",
|
||||
"\n",
|
||||
" token_ids = generate(\n",
|
||||
" model=model,\n",
|
||||
" idx=text_to_token_ids(input_text, tokenizer).to(device),\n",
|
||||
" max_new_tokens=256,\n",
|
||||
" context_size=BASE_CONFIG[\"context_length\"],\n",
|
||||
" eos_id=50256\n",
|
||||
" )\n",
|
||||
" generated_text = token_ids_to_text(token_ids, tokenizer)\n",
|
||||
"\n",
|
||||
" # New: Adjust ###Response -> <|assistant|>\n",
|
||||
" response_text = generated_text[len(input_text):].replace(\"<|assistant|>:\", \"\").strip()\n",
|
||||
"\n",
|
||||
" test_data[i][\"model_response\"] = response_text\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "29cd557c-3838-45e4-a26a-baed4b11175a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"For your convenience, the exercise solution is implemented in the [exercise_experiments.py](exercise_experiments.py) script, which you can run as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dd8158e9-cc70-4e0f-88b0-73c3e1d8c030",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
"python exercise_experiments.py --exercise_solution phi3_prompt\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Output:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"matplotlib version: 3.7.1\n",
|
||||
"tiktoken version: 0.7.0\n",
|
||||
"torch version: 2.3.0+cu121\n",
|
||||
"tqdm version: 4.66.4\n",
|
||||
"tensorflow version: 2.15.0\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Training set length: 935\n",
|
||||
"Validation set length: 55\n",
|
||||
"Test set length: 110\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Device: cuda\n",
|
||||
"--------------------------------------------------\n",
|
||||
"...\n",
|
||||
"Loaded model: gpt2-medium (355M)\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Initial losses\n",
|
||||
" Training loss: 3.71630220413208\n",
|
||||
" Validation loss: 3.6440994262695314\n",
|
||||
"Ep 1 (Step 000000): Train loss 2.633, Val loss 2.622\n",
|
||||
"...\n",
|
||||
"Ep 2 (Step 000230): Train loss 0.424, Val loss 0.928\n",
|
||||
"<|user|> Convert the active sentence to passive: 'The chef cooks the meal every day.' <|assistant|>: The meal is prepared every day by the chef....\n",
|
||||
"Training completed in 1.50 minutes.\n",
|
||||
"Plot saved as loss-plot-phi3-prompt.pdf\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Generating responses\n",
|
||||
"100% 110/110 [00:11<00:00, 9.27it/s]\n",
|
||||
"Responses saved as instruction-data-with-response-phi3-prompt.json\n",
|
||||
"Model saved as gpt2-medium355M-sft-phi3-prompt.pth\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"For comparison, you can run the original chapter 7 finetuning code via `python exercise_experiments.py --exercise_solution baseline`. \n",
|
||||
"\n",
|
||||
"Note that on an Nvidia L4 GPU, the code above, using the Phi-3 prompt template, takes 1.5 min to run. In comparison, the Alpaca-style template takes 1.80 minutes to run. So, the Phi-3 template is approximately 17% faster since it results in shorter model inputs. \n",
|
||||
"\n",
|
||||
"Let's take a look at some of the responses to make sure they have been formatted correctly:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
" {\n",
|
||||
" \"instruction\": \"Rewrite the sentence using a simile.\",\n",
|
||||
" \"input\": \"The car is very fast.\",\n",
|
||||
" \"output\": \"The car is as fast as lightning.\",\n",
|
||||
" \"model_response\": \"The car is as fast as a cheetah.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"instruction\": \"What type of cloud is typically associated with thunderstorms?\",\n",
|
||||
" \"input\": \"\",\n",
|
||||
" \"output\": \"The type of cloud typically associated with thunderstorms is cumulonimbus.\",\n",
|
||||
" \"model_response\": \"The type of cloud associated with thunderstorms is a cumulus cloud.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"instruction\": \"Name the author of 'Pride and Prejudice'.\",\n",
|
||||
" \"input\": \"\",\n",
|
||||
" \"output\": \"Jane Austen.\",\n",
|
||||
" \"model_response\": \"The author of 'Pride and Prejudice' is Jane Austen.\"\n",
|
||||
" },\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"We can evaluate the performance using the Ollama Llama 3 method, which is for your convenience, also implemented in the `python exercise_experiments.py` script, which we can run as follows:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"python ollama_evaluate.py --file_path instruction-data-with-response-phi3-prompt.json\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Output:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"Ollama running: True\n",
|
||||
"Scoring entries: 100%|████████████████████████| 110/110 [01:08<00:00, 1.60it/s]\n",
|
||||
"Number of scores: 110 of 110\n",
|
||||
"Average score: 48.87\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The score is close to 50, which is in the same ballpark as the score we previously achieved with the Alpaca-style prompts.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## Exercise 7.2: Instruction and input masking\n",
|
||||
"\n",
|
||||
"To mask out the instructions as shown in the following figure, we need to make slight modifications to the `InstructionDataset` class and `custom_collate_fn`.\n",
|
||||
"\n",
|
||||
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch07_compressed/mask-instructions.webp\" width=600px>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "4405196a-db81-470b-be39-167a059587b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This `format_input` function is copied from the original chapter 7 code\n",
|
||||
"\n",
|
||||
"def format_input(entry):\n",
|
||||
" instruction_text = (\n",
|
||||
" f\"Below is an instruction that describes a task. \"\n",
|
||||
" f\"Write a response that appropriately completes the request.\"\n",
|
||||
" f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
|
||||
"\n",
|
||||
" return instruction_text + input_text"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "83658c09-af8a-425a-b940-eb1f06e43c0b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can modify the `InstructionDataset` class to collect the lengths of the instructions, which we will use in the collate function to locate the instruction content positions in the targets when we code the collate function, as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "e5e6188a-f182-4f26-b9e5-ccae3ecadae0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch.utils.data import Dataset\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class InstructionDataset(Dataset):\n",
|
||||
" def __init__(self, data, tokenizer):\n",
|
||||
" self.data = data\n",
|
||||
"\n",
|
||||
" ##########################################################################################\n",
|
||||
" # New: Separate list for instruction lengths\n",
|
||||
" self.instruction_lengths = []\n",
|
||||
" ##########################################################################################\n",
|
||||
" \n",
|
||||
" self.encoded_texts = []\n",
|
||||
" \n",
|
||||
" for entry in data:\n",
|
||||
" instruction_plus_input = format_input(entry)\n",
|
||||
" response_text = f\"\\n\\n### Response:\\n{entry['output']}\"\n",
|
||||
" full_text = instruction_plus_input + response_text\n",
|
||||
" \n",
|
||||
" self.encoded_texts.append(\n",
|
||||
" tokenizer.encode(full_text)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" ##########################################################################################\n",
|
||||
" # New: collect instruction lengths\n",
|
||||
" instruction_length = len(tokenizer.encode(instruction_plus_input))\n",
|
||||
" self.instruction_lengths.append(instruction_length)\n",
|
||||
" ##########################################################################################\n",
|
||||
" \n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" # New: return both instruction lengths and texts separately\n",
|
||||
" return self.instruction_lengths[index], self.encoded_texts[index]\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "0163b7d1-acb8-456c-8efe-86307b58f4bb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"tokenizer = tiktoken.get_encoding(\"gpt2\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3a186394-4960-424d-bb6a-f58459dd5994",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we update the `custom_collate_fn` where each `batch` is now a tuple containing `(instruction_length, item)` instead of just `item` due to the changes in the `InstructionDataset` dataset. In addition, we now mask the corresponding instruction tokens in the target ID list."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "f815e6fc-8e54-4105-aecd-d4c6e890ff9d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def custom_collate_fn(\n",
|
||||
" batch,\n",
|
||||
" pad_token_id=50256,\n",
|
||||
" ignore_index=-100,\n",
|
||||
" allowed_max_length=None,\n",
|
||||
" device=\"cpu\"\n",
|
||||
"):\n",
|
||||
" # Find the longest sequence in the batch\n",
|
||||
" batch_max_length = max(len(item)+1 for instruction_length, item in batch) # New: batch is now a tuple\n",
|
||||
"\n",
|
||||
" # Pad and prepare inputs and targets\n",
|
||||
" inputs_lst, targets_lst = [], []\n",
|
||||
"\n",
|
||||
" for instruction_length, item in batch: # New: batch is now a tuple\n",
|
||||
" new_item = item.copy()\n",
|
||||
" # Add an <|endoftext|> token\n",
|
||||
" new_item += [pad_token_id]\n",
|
||||
" # Pad sequences to max_length\n",
|
||||
" padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))\n",
|
||||
" inputs = torch.tensor(padded[:-1]) # Truncate the last token for inputs\n",
|
||||
" targets = torch.tensor(padded[1:]) # Shift +1 to the right for targets\n",
|
||||
"\n",
|
||||
" # Replace all but the first padding tokens in targets by ignore_index\n",
|
||||
" mask = targets == pad_token_id\n",
|
||||
" indices = torch.nonzero(mask).squeeze()\n",
|
||||
" if indices.numel() > 1:\n",
|
||||
" targets[indices[1:]] = ignore_index\n",
|
||||
"\n",
|
||||
" ##########################################################################################\n",
|
||||
" # New: Mask all input and instruction tokens in the targets\n",
|
||||
" targets[:instruction_length-1] = -100\n",
|
||||
" ##########################################################################################\n",
|
||||
" \n",
|
||||
" # Optionally truncate to maximum sequence length\n",
|
||||
" if allowed_max_length is not None:\n",
|
||||
" inputs = inputs[:allowed_max_length]\n",
|
||||
" targets = targets[:allowed_max_length]\n",
|
||||
" \n",
|
||||
" inputs_lst.append(inputs)\n",
|
||||
" targets_lst.append(targets)\n",
|
||||
"\n",
|
||||
" # Convert list of inputs and targets to tensors and transfer to target device\n",
|
||||
" inputs_tensor = torch.stack(inputs_lst).to(device)\n",
|
||||
" targets_tensor = torch.stack(targets_lst).to(device)\n",
|
||||
"\n",
|
||||
" return inputs_tensor, targets_tensor"
|
||||
]
|
||||
},
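{
"cell_type": "markdown",
"id": "added-shift-offset-note",
"metadata": {},
"source": [
"Why `instruction_length-1` rather than `instruction_length`? Because the targets are the inputs shifted by one position, the first `instruction_length-1` target entries are the ones that correspond to instruction tokens. A minimal sketch with toy token IDs (assuming a 5-token instruction):\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"instruction_length = 5\n",
"padded = list(range(10))            # toy token IDs; the first 5 are \"instruction\" tokens\n",
"targets = torch.tensor(padded[1:])  # shifted by one position, as in the collate function\n",
"targets[:instruction_length-1] = -100\n",
"print(targets)  # tensor([-100, -100, -100, -100,    5,    6,    7,    8,    9])\n",
"```"
]
},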
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0a4a4815-850e-42c4-b70d-67e8ce5ebd57",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's try it out on some sample data below:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "8da8a5b1-a8e2-4389-b21c-25b67be6dd1c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sample_data = [\n",
|
||||
" {'instruction': \"What is an antonym of 'complicated'?\", 'input': '', 'output': \"An antonym of 'complicated' is 'simple'.\"},\n",
|
||||
" {'instruction': 'Sort the following list in alphabetical order.', 'input': 'Zebra, Elephant, Crocodile', 'output': 'Crocodile, Elephant, Zebra'},\n",
|
||||
" {'instruction': 'Arrange the given numbers in descending order.', 'input': '5, 12, 8, 3, 15', 'output': '15, 12, 8, 5, 3.'}\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "435b0816-0fc8-4650-a84a-eceffa4d85e4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"\n",
|
||||
"train_dataset = InstructionDataset(sample_data, tokenizer)\n",
|
||||
"train_loader = DataLoader(\n",
|
||||
" train_dataset,\n",
|
||||
" batch_size=len(sample_data),\n",
|
||||
" collate_fn=custom_collate_fn,\n",
|
||||
" num_workers=0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "106bbbd7-7286-4eb6-b343-43419332a80f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train loader:\n",
|
||||
"torch.Size([3, 64]) torch.Size([3, 64])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"Train loader:\")\n",
|
||||
"for inputs, targets in train_loader:\n",
|
||||
" print(inputs.shape, targets.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "9bb3288b-84a9-4962-ae59-a7a29fd34bce",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Inputs:\n",
|
||||
" tensor([21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430,\n",
|
||||
" 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198,\n",
|
||||
" 21017, 46486, 25, 198, 42758, 262, 1708, 1351, 287, 24830,\n",
|
||||
" 605, 1502, 13, 198, 198, 21017, 23412, 25, 198, 57,\n",
|
||||
" 37052, 11, 42651, 11, 9325, 19815, 576, 198, 198, 21017,\n",
|
||||
" 18261, 25, 198, 34, 12204, 375, 576, 11, 42651, 11,\n",
|
||||
" 1168, 37052, 50256, 50256])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Targets:\n",
|
||||
" tensor([ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
|
||||
" -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
|
||||
" -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
|
||||
" -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
|
||||
" -100, -100, -100, -100, -100, -100, 198, 198, 21017, 18261,\n",
|
||||
" 25, 198, 34, 12204, 375, 576, 11, 42651, 11, 1168,\n",
|
||||
" 37052, 50256, -100, -100])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"Inputs:\\n\", inputs[1])\n",
|
||||
"print(\"\\n\\nTargets:\\n\", targets[1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cc40347b-2ca7-44e1-862d-0fd0c92f0628",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As we can see based on the `targets` tensor, both the instruction and padding tokens are now masked using the -100 placeholder tokens. \n",
|
||||
"Let's decode the inputs just to make sure that they look correct:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "76a9e6fa-3d75-4e39-b139-c3e05048f42b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
|
||||
"\n",
|
||||
"### Instruction:\n",
|
||||
"Sort the following list in alphabetical order.\n",
|
||||
"\n",
|
||||
"### Input:\n",
|
||||
"Zebra, Elephant, Crocodile\n",
|
||||
"\n",
|
||||
"### Response:\n",
|
||||
"Crocodile, Elephant, Zebra<|endoftext|><|endoftext|>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(tokenizer.decode(list(inputs[1])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "845ebd36-f63f-4b58-a76e-7767e4d2ccbd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, let's decode the non-masked target token IDS:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "4d54a152-b778-455a-8941-e375e2a17e8f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"### Response:\n",
|
||||
"Crocodile, Elephant, Zebra<|endoftext|>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"non_masked_targets = targets[1][targets[1] != -100]\n",
|
||||
"\n",
|
||||
"print(tokenizer.decode(list(non_masked_targets)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3912bbf5-e9e2-474b-9552-d522e7510aa6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As shown above, the non-masked target tokens exclude the `\"Instruction\"` and `\"Input\"` fields, as intended. Now, we can run the modified code to see how well the LLM performs when finetuned using this masking strategy.\n",
|
||||
"\n",
|
||||
"For your convenience, you can use the `exercise_experiments.py` code to run a comparison as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "56a76097-9114-479d-8803-443b0ff48581",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```bash\n",
|
||||
"python exercise_experiments.py --exercise_solution mask_instructions\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Output:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"matplotlib version: 3.7.1\n",
|
||||
"tiktoken version: 0.7.0\n",
|
||||
"torch version: 2.3.0+cu121\n",
|
||||
"tqdm version: 4.66.4\n",
|
||||
"tensorflow version: 2.15.0\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Training set length: 935\n",
|
||||
"Validation set length: 55\n",
|
||||
"Test set length: 110\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Device: cuda\n",
|
||||
"--------------------------------------------------\n",
|
||||
"...\n",
|
||||
"Loaded model: gpt2-medium (355M)\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Initial losses\n",
|
||||
" Training loss: 2.280539035797119\n",
|
||||
" Validation loss: 2.262560224533081\n",
|
||||
"Ep 1 (Step 000000): Train loss 1.636, Val loss 1.620\n",
|
||||
"...\n",
|
||||
"Ep 2 (Step 000230): Train loss 0.143, Val loss 0.727\n",
|
||||
"...\n",
|
||||
"Training completed in 1.77 minutes.\n",
|
||||
"Plot saved as loss-plot-mask-instructions.pdf\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Generating responses\n",
|
||||
"100% 110/110 [02:10<00:00, 1.19s/it]\n",
|
||||
"Responses saved as instruction-data-with-response-mask-instructions.json\n",
|
||||
"Model saved as gpt2-medium355M-sft-mask-instructions.pth\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Next, let's evaluate the performance of the resulting LLM:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"python ollama_evaluate.py --file_path instruction-data-with-response-mask-instructions.json\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"Ollama running: True\n",
|
||||
"Scoring entries: 100%|██████████████████████████████████████████████████████████████████████████████████████| 110/110 [01:23<00:00, 1.31it/s]\n",
|
||||
"Number of scores: 110 of 110\n",
|
||||
"Average score: 47.73\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"As we can see based on the scores, the instruction masking does perform slightly worse, which is consistent with the observation in the \"Instruction Tuning With Loss Over Instructions\" paper (https://arxiv.org/abs/2405.14394)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "94a0f758-29da-44ee-b7af-32473b3c086e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"## Exercise 7.3: Finetuning on the original Alpaca dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "68df7616-679f-4e53-954d-6e7cf2e2ef55",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To finetune the model on the original Stanford Alpaca dataset ([https://github.com/tatsu-lab/stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca)), you just need to change the file URL from\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json\"\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"to\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"url = \"https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json\"\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Note that the dataset contains 52k entries (50x more than in chapter 7), and the entries are longer than the ones we worked with in chapter 7.\n",
|
||||
"Thus, it's highly recommended that the training be run on a GPU.\n",
|
||||
"\n",
|
||||
"If you encounter out-of-memory errors, consider reducing the batch size from 8 to 4, 2, or 1. In addition to lowering the batch size, you may also want to consider lowering the `allowed_max_length` from 1024 to 512 or 256."
|
||||
]
|
||||
},
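{
"cell_type": "markdown",
"id": "added-memory-settings-sketch",
"metadata": {},
"source": [
"In code, these memory-saving adjustments boil down to two values that are set before the data loaders are created. A minimal sketch (hypothetical settings; tune them to your hardware, and note that `custom_collate_fn` and `device` are assumed to be defined as in the main chapter code):\n",
"\n",
"```python\n",
"from functools import partial\n",
"\n",
"batch_size = 4            # reduced from 8 to lower peak GPU memory\n",
"allowed_max_length = 512  # reduced from 1024; truncates long Alpaca entries\n",
"\n",
"customized_collate_fn = partial(\n",
"    custom_collate_fn, device=device, allowed_max_length=allowed_max_length\n",
")\n",
"```"
]
},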
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d94c9621-2c3f-4551-b5b8-87cd96e38c9c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"For your convenience, you can use the `exercise_experiments.py` code to finetune the model on the 52k Alpaca dataset with a batch size of 4 and an `allowed_max_length` of 512 as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "40a76486-73e6-4415-94dc-bfe2aa36ea52",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```bash\n",
|
||||
"python exercise_experiments.py --exercise_solution alpaca_52k\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"matplotlib version: 3.7.1\n",
|
||||
"tiktoken version: 0.7.0\n",
|
||||
"torch version: 2.3.0+cu121\n",
|
||||
"tqdm version: 4.66.4\n",
|
||||
"tensorflow version: 2.15.0\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Training set length: 44201\n",
|
||||
"Validation set length: 2601\n",
|
||||
"Test set length: 5200\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Device: cuda\n",
|
||||
"--------------------------------------------------\n",
|
||||
"...\n",
|
||||
"Loaded model: gpt2-medium (355M)\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Initial losses\n",
|
||||
" Training loss: 3.3681655883789063\n",
|
||||
" Validation loss: 3.4122894287109373\n",
|
||||
"Ep 1 (Step 000000): Train loss 2.477, Val loss 2.750\n",
|
||||
"...\n",
|
||||
"Ep 2 (Step 022095): Train loss 0.761, Val loss 1.557\n",
|
||||
"...\n",
|
||||
"Training completed in 196.38 minutes.\n",
|
||||
"Plot saved as loss-plot-alpaca52k.pdf\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Generating responses\n",
|
||||
"100% 5200/5200 [2:56:33<00:00, 2.04s/it]\n",
|
||||
"Responses saved as instruction-data-with-response-alpaca52k.json\n",
|
||||
"Model saved as gpt2-medium355M-sft-alpaca52k.pth\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cc8a1bd1-15f4-4a35-87b5-369038db0aa7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Below are a few examples from the Alpaca dataset, including the generated model responses:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "be9ab66f-5819-4b01-9a03-c45aa3b7c5b8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```\n",
|
||||
"[\n",
|
||||
" {\n",
|
||||
" \"instruction\": \"Edit the following sentence to increase readability: \\\"He made a huge effort and was so successful.\\\"\",\n",
|
||||
" \"input\": \"\",\n",
|
||||
" \"output\": \"He exerted a tremendous effort, and thus enjoyed great success.\",\n",
|
||||
" \"model_response\": \"He put in an immense effort and was rewarded with success.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"instruction\": \"Rewrite the following sentence to make it more concise: \\\"I was displeased with the result of the experiment that I conducted.\\\"\",\n",
|
||||
" \"input\": \"\",\n",
|
||||
" \"output\": \"I was unhappy with my experiment's outcome.\",\n",
|
||||
" \"model_response\": \"I was displeased with the results of the experiment.\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"instruction\": \"How can we build a more efficient GPT model?\",\n",
|
||||
" \"input\": \"\",\n",
|
||||
" \"output\": \"We can build a more efficient GPT model by optimizing the architecture of the model, using smaller model sizes and training with fewer parameters. We can also leverage techniques such as knowledge distillation, transfer learning, dynamic sparsity and hybrid computing to further improve the efficiency of the model.\",\n",
|
||||
" \"model_response\": \"Building a more efficient GPT model requires careful planning and optimization. First, it is important to identify the target language and the context in which the model is used. Then, it is important to select the appropriate model architecture, such as backpropagation, hyperparameters, and hyperparameters. Finally, it is important to select the appropriate model weights and optimizers, such as backpropagation, hyperparameters, and hyperparameters.\"\n",
|
||||
" },\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "840e2076-f2e6-44a2-86fd-f191f9674267",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, we can evaluate the finetuned LLM using the [ollama_evaluate.py](ollama_evaluate.py) utility function:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"python ollama_evaluate.py --file_path instruction-data-with-response-alpaca52k.json\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"Scoring entries: 100%|████████████████████| 5200/5200 [1:07:52<00:00, 1.28it/s]\n",
|
||||
"Number of scores: 5188 of 5200\n",
|
||||
"Average score: 48.16\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d14b3c60-00a1-43a9-9fcd-592aaadf1ef4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The score is slightly lower than the score we obtained on the dataset we used in this chapter. However, note that the Alpaca test set contains more diverse and partly more challenging instructions than the dataset we used in the main chapter."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
ch07/01_main-chapter-code/exercise_experiments.py (new file, 502 lines)
@@ -0,0 +1,502 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
#
|
||||
# Code to run the exercises; see exercise-solutions.ipynb for more information
|
||||
|
||||
from functools import partial
|
||||
from importlib.metadata import version
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import urllib
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import tiktoken
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
# Import from local files in this folder
|
||||
from gpt_download import download_and_load_gpt2
|
||||
from previous_chapters import (
|
||||
calc_loss_loader,
|
||||
generate,
|
||||
GPTModel,
|
||||
load_weights_into_gpt,
|
||||
text_to_token_ids,
|
||||
train_model_simple,
|
||||
token_ids_to_text
|
||||
)
|
||||
|
||||
|
||||
class InstructionDataset(Dataset):
|
||||
def __init__(self, data, tokenizer):
|
||||
self.data = data
|
||||
|
||||
# Pre-tokenize texts
|
||||
self.encoded_texts = []
|
||||
for entry in data:
|
||||
instruction_plus_input = format_input(entry)
|
||||
response_text = f"\n\n### Response:\n{entry['output']}"
|
||||
full_text = instruction_plus_input + response_text
|
||||
self.encoded_texts.append(
|
||||
tokenizer.encode(full_text)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.encoded_texts[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class InstructionDatasetWithMasking(Dataset):
|
||||
def __init__(self, data, tokenizer):
|
||||
self.data = data
|
||||
|
||||
# New: Separate list for instruction lengths
|
||||
self.instruction_lengths = []
|
||||
self.encoded_texts = []
|
||||
|
||||
for entry in data:
|
||||
instruction_plus_input = format_input(entry)
|
||||
response_text = f"\n\n### Response:\n{entry['output']}"
|
||||
full_text = instruction_plus_input + response_text
|
||||
|
||||
self.encoded_texts.append(
|
||||
tokenizer.encode(full_text)
|
||||
)
|
||||
|
||||
# New: collect instruction lengths
|
||||
instruction_length = len(tokenizer.encode(instruction_plus_input))
|
||||
self.instruction_lengths.append(instruction_length)
|
||||
|
||||
def __getitem__(self, index):
|
||||
# New: return both instruction lengths and texts separately
|
||||
return self.instruction_lengths[index], self.encoded_texts[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class InstructionDatasetPhi(Dataset):
|
||||
def __init__(self, data, tokenizer):
|
||||
self.data = data
|
||||
|
||||
# Pre-tokenize texts
|
||||
self.encoded_texts = []
|
||||
for entry in data:
|
||||
|
||||
###################################################################
|
||||
# NEW: Use `format_input_phi` and adjust the response text template
|
||||
instruction_plus_input = format_input_phi(entry)
|
||||
response_text = f"\n<|assistant|>:\n{entry['output']}"
|
||||
###################################################################
|
||||
full_text = instruction_plus_input + response_text
|
||||
self.encoded_texts.append(
|
||||
tokenizer.encode(full_text)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.encoded_texts[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def custom_collate_fn(
|
||||
batch,
|
||||
pad_token_id=50256,
|
||||
ignore_index=-100,
|
||||
allowed_max_length=None,
|
||||
device="cpu"
|
||||
):
|
||||
# Find the longest sequence in the batch
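# (the +1 accounts for the <|endoftext|> token appended to each sequence below)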
|
||||
batch_max_length = max(len(item)+1 for item in batch)
|
||||
|
||||
# Pad and prepare inputs and targets
|
||||
inputs_lst, targets_lst = [], []
|
||||
|
||||
for item in batch:
|
||||
new_item = item.copy()
|
||||
# Add an <|endoftext|> token
|
||||
new_item += [pad_token_id]
|
||||
# Pad sequences to max_length
|
||||
padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
|
||||
inputs = torch.tensor(padded[:-1]) # Truncate the last token for inputs
|
||||
targets = torch.tensor(padded[1:]) # Shift +1 to the right for targets
|
||||
|
||||
# New: Replace all but the first padding tokens in targets by ignore_index
|
||||
mask = targets == pad_token_id
|
||||
indices = torch.nonzero(mask).squeeze()
|
||||
if indices.numel() > 1:
|
||||
targets[indices[1:]] = ignore_index
|
||||
|
||||
# New: Optionally truncate to maximum sequence length
|
||||
if allowed_max_length is not None:
|
||||
inputs = inputs[:allowed_max_length]
|
||||
targets = targets[:allowed_max_length]
|
||||
|
||||
inputs_lst.append(inputs)
|
||||
targets_lst.append(targets)
|
||||
|
||||
# Convert list of inputs and targets to tensors and transfer to target device
|
||||
inputs_tensor = torch.stack(inputs_lst).to(device)
|
||||
targets_tensor = torch.stack(targets_lst).to(device)
|
||||
|
||||
return inputs_tensor, targets_tensor
|
||||
|
||||
|
||||
def custom_collate_with_masking_fn(
|
||||
batch,
|
||||
pad_token_id=50256,
|
||||
ignore_index=-100,
|
||||
allowed_max_length=None,
|
||||
device="cpu"
|
||||
):
|
||||
# Find the longest sequence in the batch
|
||||
batch_max_length = max(len(item)+1 for instruction_length, item in batch) # New: batch is now a tuple
|
||||
|
||||
# Pad and prepare inputs and targets
|
||||
inputs_lst, targets_lst = [], []
|
||||
|
||||
for instruction_length, item in batch: # New: batch is now a tuple
|
||||
new_item = item.copy()
|
||||
# Add an <|endoftext|> token
|
||||
new_item += [pad_token_id]
|
||||
# Pad sequences to max_length
|
||||
padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
|
||||
inputs = torch.tensor(padded[:-1]) # Truncate the last token for inputs
|
||||
targets = torch.tensor(padded[1:]) # Shift +1 to the right for targets
|
||||
|
||||
# Replace all but the first padding tokens in targets by ignore_index
|
||||
mask = targets == pad_token_id
|
||||
indices = torch.nonzero(mask).squeeze()
|
||||
if indices.numel() > 1:
|
||||
targets[indices[1:]] = ignore_index
|
||||
|
||||
# New: Mask all input and instruction tokens in the targets
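# (targets are shifted by one position relative to the inputs, so the
# first instruction_length-1 target entries correspond to instruction tokens)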
|
||||
targets[:instruction_length-1] = ignore_index
|
||||
|
||||
# Optionally truncate to maximum sequence length
|
||||
if allowed_max_length is not None:
|
||||
inputs = inputs[:allowed_max_length]
|
||||
targets = targets[:allowed_max_length]
|
||||
|
||||
inputs_lst.append(inputs)
|
||||
targets_lst.append(targets)
|
||||
|
||||
# Convert list of inputs and targets to tensors and transfer to target device
|
||||
inputs_tensor = torch.stack(inputs_lst).to(device)
|
||||
targets_tensor = torch.stack(targets_lst).to(device)
|
||||
|
||||
return inputs_tensor, targets_tensor
|
||||
|
||||
|
||||
def download_and_load_file(file_path, url):
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
with urllib.request.urlopen(url) as response:
|
||||
text_data = response.read().decode("utf-8")
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
file.write(text_data)
|
||||
else:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
text_data = file.read()
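# Note: text_data is not used further; the JSON contents are loaded from disk below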
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
data = json.load(file)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def format_input_phi(entry):
|
||||
instruction_text = (
|
||||
f"<|user|>\n{entry['instruction']}"
|
||||
)
|
||||
|
||||
input_text = f"\n{entry['input']}" if entry["input"] else ""
|
||||
|
||||
return instruction_text + input_text
|
||||
|
||||
|
||||
def format_input(entry):
|
||||
instruction_text = (
|
||||
f"Below is an instruction that describes a task. "
|
||||
f"Write a response that appropriately completes the request."
|
||||
f"\n\n### Instruction:\n{entry['instruction']}"
|
||||
)
|
||||
|
||||
input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""
|
||||
|
||||
return instruction_text + input_text
|
||||
|
||||
|
||||
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
|
||||
fig, ax1 = plt.subplots(figsize=(12, 6))
|
||||
|
||||
# Plot training and validation loss against epochs
|
||||
ax1.plot(epochs_seen, train_losses, label="Training loss")
|
||||
ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
|
||||
ax1.set_xlabel("Epochs")
|
||||
ax1.set_ylabel("Loss")
|
||||
ax1.legend(loc="upper right")
|
||||
|
||||
# Create a second x-axis for tokens seen
|
||||
ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis
|
||||
ax2.plot(tokens_seen, train_losses, alpha=0) # Invisible plot for aligning ticks
|
||||
ax2.set_xlabel("Tokens seen")
|
||||
|
||||
fig.tight_layout() # Adjust layout to make room
|
||||
print(f"Plot saved as {plot_name}")
|
||||
plt.savefig(plot_name)
|
||||
# plt.show()
|
||||
|
||||
|
||||
def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
|
||||
#######################################
|
||||
# Print package versions
|
||||
#######################################
|
||||
print()
|
||||
pkgs = [
|
||||
"matplotlib", # Plotting library
|
||||
"tiktoken", # Tokenizer
|
||||
"torch", # Deep learning library
|
||||
"tqdm", # Progress bar
|
||||
"tensorflow", # For OpenAI's pretrained weights
|
||||
]
|
||||
for p in pkgs:
|
||||
print(f"{p} version: {version(p)}")
|
||||
print(50*"-")
|
||||
|
||||
#######################################
|
||||
# Download and prepare dataset
|
||||
#######################################
|
||||
file_path = "instruction-data.json"
|
||||
|
||||
if alpaca52k:
|
||||
url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json"
|
||||
else:
|
||||
url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json"
|
||||
data = download_and_load_file(file_path, url)
|
||||
|
||||
train_portion = int(len(data) * 0.85) # 85% for training
|
||||
test_portion = int(len(data) * 0.1) # 10% for testing
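# The remaining 5% of the data is used for validation (val_data below)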
|
||||
|
||||
train_data = data[:train_portion]
|
||||
test_data = data[train_portion:train_portion + test_portion]
|
||||
val_data = data[train_portion + test_portion:]
|
||||
|
||||
print("Training set length:", len(train_data))
|
||||
print("Validation set length:", len(val_data))
|
||||
print("Test set length:", len(test_data))
|
||||
print(50*"-")
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print("Device:", device)
|
||||
print(50*"-")
|
||||
|
||||
if alpaca52k:
|
||||
allowed_max_length = 512
|
||||
else:
|
||||
allowed_max_length = 1024
|
||||
|
||||
if mask_instructions and phi3_prompt:
|
||||
raise ValueError("Simultaneous support for instruction masking and the Phi-3 prompt template has not been implemented, yet.")
|
||||
|
||||
if mask_instructions:
|
||||
customized_collate_fn = partial(custom_collate_with_masking_fn, device=device, allowed_max_length=allowed_max_length)
|
||||
CustomDataset = InstructionDatasetWithMasking
|
||||
elif phi3_prompt:
|
||||
customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=allowed_max_length)
|
||||
CustomDataset = InstructionDatasetPhi
|
||||
else:
|
||||
customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=allowed_max_length)
|
||||
CustomDataset = InstructionDataset
|
||||
|
||||
num_workers = 0
|
||||
|
||||
if alpaca52k:
|
||||
batch_size = 4
|
||||
else:
|
||||
batch_size = 8
|
||||
|
||||
torch.manual_seed(123)
|
||||
|
||||
train_dataset = CustomDataset(train_data, tokenizer)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=customized_collate_fn,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
num_workers=num_workers
|
||||
)
|
||||
|
||||
val_dataset = CustomDataset(val_data, tokenizer)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=customized_collate_fn,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
num_workers=num_workers
|
||||
)
|
||||
|
||||
#######################################
|
||||
# Load pretrained model
|
||||
#######################################
|
||||
BASE_CONFIG = {
|
||||
"vocab_size": 50257, # Vocabulary size
|
||||
"context_length": 1024, # Context length
|
||||
"drop_rate": 0.0, # Dropout rate
|
||||
"qkv_bias": True # Query-key-value bias
|
||||
}
|
||||
|
||||
model_configs = {
|
||||
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
|
||||
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
|
||||
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
|
||||
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
|
||||
}
|
||||
|
||||
CHOOSE_MODEL = "gpt2-medium (355M)"
|
||||
|
||||
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
|
||||
|
||||
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
|
||||
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
||||
|
||||
model = GPTModel(BASE_CONFIG)
|
||||
load_weights_into_gpt(model, params)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
print("Loaded model:", CHOOSE_MODEL)
|
||||
print(50*"-")
|
||||
|
||||
#######################################
|
||||
# Finetuning the model
|
||||
#######################################
|
||||
print("Initial losses")
|
||||
with torch.no_grad():
|
||||
train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
|
||||
val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)
|
||||
|
||||
print(" Training loss:", train_loss)
|
||||
print(" Validation loss:", val_loss)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
num_epochs = 2
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)
|
||||
|
||||
torch.manual_seed(123)
|
||||
|
||||
start_context = format_input_phi(val_data[0]) if phi3_prompt else format_input(val_data[0])
|
||||
|
||||
train_losses, val_losses, tokens_seen = train_model_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
num_epochs=num_epochs, eval_freq=5, eval_iter=5,
|
||||
start_context=start_context, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
execution_time_minutes = (end_time - start_time) / 60
|
||||
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
|
||||
|
||||
epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
|
||||
|
||||
plot_name = "loss-plot.pdf"
|
||||
if mask_instructions:
|
||||
plot_name = plot_name.replace(".pdf", "-mask-instructions.pdf")
|
||||
if alpaca52k:
|
||||
plot_name = plot_name.replace(".pdf", "-alpaca52k.pdf")
|
||||
if phi3_prompt:
|
||||
plot_name = plot_name.replace(".pdf", "-phi3-prompt.pdf")
|
||||
if not any([mask_instructions, alpaca52k, phi3_prompt]):
|
||||
plot_name = plot_name.replace(".pdf", "-baseline.pdf")
|
||||
|
||||
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, plot_name)
|
||||
print(50*"-")
|
||||
|
||||
#######################################
|
||||
# Saving results
|
||||
#######################################
|
||||
print("Generating responses")
|
||||
for i, entry in tqdm(enumerate(test_data), total=len(test_data)):
|
||||
|
||||
input_text = format_input_phi(entry) if phi3_prompt else format_input(entry)
|
||||
|
||||
token_ids = generate(
|
||||
model=model,
|
||||
idx=text_to_token_ids(input_text, tokenizer).to(device),
|
||||
max_new_tokens=256,
|
||||
context_size=BASE_CONFIG["context_length"],
|
||||
eos_id=50256
|
||||
)
|
||||
generated_text = token_ids_to_text(token_ids, tokenizer)
|
||||
|
||||
if phi3_prompt:
|
||||
response_text = generated_text[len(input_text):].replace("<|assistant|>:", "").strip()
|
||||
else:
|
||||
response_text = generated_text[len(input_text):].replace("### Response:", "").strip()
|
||||
|
||||
test_data[i]["model_response"] = response_text
|
||||
|
||||
test_data_path = "instruction-data-with-response.json"
|
||||
file_name = f"{re.sub(r'[ ()]', '', CHOOSE_MODEL) }-sft.pth"
|
||||
|
||||
if mask_instructions:
|
||||
test_data_path = test_data_path.replace(".json", "-mask-instructions.json")
|
||||
file_name = file_name.replace(".pth", "-mask-instructions.pth")
|
||||
if alpaca52k:
|
||||
test_data_path = test_data_path.replace(".json", "-alpaca52k.json")
|
||||
file_name = file_name.replace(".pth", "-alpaca52k.pth")
|
||||
if phi3_prompt:
|
||||
test_data_path = test_data_path.replace(".json", "-phi3-prompt.json")
|
||||
file_name = file_name.replace(".pth", "-phi3-prompt.pth")
|
||||
if not any([mask_instructions, alpaca52k, phi3_prompt]):
|
||||
test_data_path = test_data_path.replace(".json", "-baseline.json")
|
||||
file_name = file_name.replace(".pth", "-baseline.pth")
|
||||
|
||||
with open(test_data_path, "w") as file:
|
||||
json.dump(test_data, file, indent=4) # "indent" for pretty-printing
|
||||
print(f"Responses saved as {test_data_path}")
|
||||
|
||||
torch.save(model.state_dict(), file_name)
|
||||
print(f"Model saved as {file_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Instruction finetune a GPT model"
|
||||
)
|
||||
options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt"}
|
||||
parser.add_argument(
|
||||
"--exercise_solution",
|
||||
type=str,
|
||||
default="last_block",
|
||||
help=(
|
||||
f"Which experiment to run. Options: {options}."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.exercise_solution == "baseline":
|
||||
main()
|
||||
elif args.exercise_solution == "mask_instructions":
|
||||
main(mask_instructions=True)
|
||||
elif args.exercise_solution == "alpaca_52k":
|
||||
main(alpaca52k=True)
|
||||
elif args.exercise_solution == "phi3_prompt":
|
||||
main(phi3_prompt=True)
|
||||
else:
|
||||
raise ValueError(f"{args.exercise_solution} is not a valid --args.exercise_solution option. Options: {options}")
|
ch07/01_main-chapter-code/ollama_evaluate.py
@@ -83,6 +83,9 @@ def main(file_path):
|
||||
def generate_model_scores(json_data, json_key, model="llama3"):
    scores = []
    for entry in tqdm(json_data, desc="Scoring entries"):
        if entry[json_key] == "":
            scores.append(0)
        else:
            prompt = (
                f"Given the input `{format_input(entry)}` "
                f"and correct output `{entry['output']}`, "
@@ -105,7 +108,7 @@ if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
-    description="Instruction finetune a GPT model"
+    description="Evaluate model responses with ollama"
)
|
||||
parser.add_argument(
|
||||
"--file_path",
|
||||