"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"# Direct Preference Optimization (DPO) for LLM Alignment (From Scratch)"
]
},
{
"cell_type": "markdown",
"id": "d04cb2b8-d87b-4c6b-a225-c630d758f68e",
"metadata": {
"id": "d04cb2b8-d87b-4c6b-a225-c630d758f68e"
},
"source": [
"- This code notebook implements Direct Preference Optimization (DPO) from scratch and applies it to a large language model (LLM) to enhance its ability to generate responses that align more closely with user preferences"
"- DPO, proposed in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290), is an alternative to reinforcement learning from human feedback (RLHF) used in finetuning large language models (LLMs)\n",
"- DPO can be used to finetune (or align) the model to generate responses that better align with user expectations and instructions\n",
"- In instruction finetuning, we train the LLM to generate correct answers given a prompt\n",
"- However, in practice, there are multiple ways to give a correct answer, and correct answers can differ in style; for example, consider a technical and a more user-friendly response when asking an LLM to give recommendations when buying a laptop, as shown in the figure below\n",
"- RLHF and DPO are methods that can be used to teach the LLM to prefer one answer style over the other, that is, aligning better with user preferences\n",
"- The RLHF process, which requires training a separate reward model, is outlined below\n",
"- Compared to RLHF, DPO aims to simplify the process by optimizing models directly for user preferences without the need for complex reward modeling and policy optimization\n",
"- In other words, DPO focuses on directly optimizing the model's output to align with human preferences or specific objectives\n",
"- Shown below is the main idea as an overview of how DPO works\n",
"- The concrete equation to implement the DPO loss is shown below; we will revisit the equation when we implement it in Python further down in this code notebook\n",
" - \"expected value\" $\\mathbb{E}$ is statistics jargon and stands for the average or mean value of the random variable (the expression inside the brackets)\n",
" - The $\\pi_{\\theta}$ variable is the so-called policy (a term borrowed from reinforcement learning) and represents the LLM we want to optimize; $\\pi_{ref}$ is a reference LLM, which is typically the original LLM before optimization (at the beginning of the training, $\\pi_{\\theta}$ and $\\pi_{ref}$ are typically the same)\n",
" - $\\beta$ is a hyperparameter to control the divergence between the $\\pi_{\\theta}$ and the reference model; increasing $\\beta$ increases the impact of the difference between\n",
"$\\pi_{\\theta}$ and $\\pi_{ref}$ in terms of their log probabilities on the overall loss function, thereby increasing the divergence between the two models\n",
"- To avoid bloating the code notebook with a more detailed discussion, I may write a separate standalone article with more details on these concepts in the future\n",
"- In the meantime, if you are interested in comparing RLHF and DPO, please see the section [2.2. RLHF vs Direct Preference Optimization (DPO)](https://magazine.sebastianraschka.com/i/142924793/rlhf-vs-direct-preference-optimization-dpo) in my article [Tips for LLM Pretraining and Evaluating Reward Models](https://magazine.sebastianraschka.com/p/tips-for-llm-pretraining-and-evaluating-rms)"
]
},
{
"cell_type": "markdown",
"id": "xqVAgsyQ6LuG",
"metadata": {
"id": "xqVAgsyQ6LuG",
"tags": []
},
"source": [
" \n",
"# 2) Preparing a preference dataset for DPO"
]
},
{
"cell_type": "markdown",
"id": "60b2195d-8734-469b-a52e-5031ca7ea6b1",
"metadata": {
"id": "60b2195d-8734-469b-a52e-5031ca7ea6b1"
},
"source": [
"- Let's begin by loading and preparing the dataset, which may already answer a lot of the questions you might have before we revisit the DPO loss equation\n",
"- Here, we work with a dataset that contains more polite and less polite responses to instruction prompts (concrete examples are shown in the next section)\n",
"- The dataset was generated via the [create-preference-data-ollama.ipynb](create-preference-data-ollama.ipynb) notebook"
"{'instruction': \"What is an antonym of 'complicated'?\",\n",
" 'input': '',\n",
" 'output': \"An antonym of 'complicated' is 'simple'.\",\n",
" 'chosen': \"A suitable antonym for 'complicated' would be 'simple'.\",\n",
" 'rejected': \"An antonym of 'complicated' is 'simple'.\"}\n"
]
}
],
"source": [
"pprint.pp(data[999])"
]
},
{
"cell_type": "markdown",
"id": "56db5697-a089-4b40-a1f3-e928e8018220",
"metadata": {
"id": "56db5697-a089-4b40-a1f3-e928e8018220"
},
"source": [
"\n",
"\n",
"```\n",
"# This is formatted as code\n",
"```\n",
"\n",
"- As we can see above, the dataset consists of 5 keys:\n",
" - The `'instruction'` and `'input'` that are used as LLM inputs\n",
" - The `'output'` contains the response the model was trained on via the instruction finetuning step in chapter 7\n",
" - the `'chosen'` and `'rejected'` entries are the entries we use for DPO; here `'chosen'` is the preferred response, and `'rejected'` is the dispreferred response\n",
"- The goal is to get the model to follow the style of the chosen over the rejected responses"
]
},
{
"cell_type": "markdown",
"id": "86257468-a6ab-4ba3-9c9f-2fdc2c0cc284",
"metadata": {
"id": "86257468-a6ab-4ba3-9c9f-2fdc2c0cc284"
},
"source": [
"- Below is a utility function that formats the model input by applying the Alpaca prompt style similar to chapter 7 ([../01_main-chapter-code/ch07.ipynb](../01_main-chapter-code/ch07.ipynb)):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4564d55c-1c5d-46a6-b5e8-46ab568ad627",
"metadata": {
"id": "4564d55c-1c5d-46a6-b5e8-46ab568ad627"
},
"outputs": [],
"source": [
"def format_input(entry):\n",
" instruction_text = (\n",
" f\"Below is an instruction that describes a task. \"\n",
" f\"Write a response that appropriately completes the request.\"\n",
"print(\"Training set length:\", len(train_data))\n",
"print(\"Validation set length:\", len(val_data))\n",
"print(\"Test set length:\", len(test_data))"
]
},
{
"cell_type": "markdown",
"id": "c07d09f7-66af-49ed-8b9e-484f46e6a68d",
"metadata": {
"id": "c07d09f7-66af-49ed-8b9e-484f46e6a68d"
},
"source": [
" \n",
"## 2.3) Developing a `PreferenceDataset` class and batch processing function"
]
},
{
"cell_type": "markdown",
"id": "86101174-00c8-485d-8273-d086d5311926",
"metadata": {
"id": "86101174-00c8-485d-8273-d086d5311926"
},
"source": [
"- In this section, we rewrite the `InstructionDataset` class from chapter 7 ([../01_main-chapter-code/ch07.ipynb](../01_main-chapter-code/ch07.ipynb)) for DPO\n",
"- This means that instead of focusing on single output sequences (responses), we modify the dataset class to return pairs of responses where one is preferred (\"chosen\") over the other (\"rejected\")\n",
"- Overall, the `PreferenceDataset` is almost identical to the `InstructionDataset` used in chapter 7:"
"- Along with an updated `PreferenceDataset` class, we also need an updated batch collation function that we use to pad the sequences in each batch to an equal length so that we can assemble them in batches\n",
"- I added comments to the code below to illustrate the process; however, it might be easiest to understand how it works by looking at the example inputs and outputs further below:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8d3a43a6-7704-4bff-9bbc-a38632374f30",
"metadata": {
"id": "8d3a43a6-7704-4bff-9bbc-a38632374f30"
},
"outputs": [],
"source": [
"def custom_collate_fn(\n",
" batch,\n",
" pad_token_id=50256,\n",
" allowed_max_length=None,\n",
" mask_prompt_tokens=True,\n",
" device=\"cpu\"\n",
"):\n",
" # Initialize lists to hold batch data\n",
" batch_data = {\n",
" \"prompt\": [],\n",
" \"chosen\": [],\n",
" \"rejected\": [],\n",
" \"rejected_mask\": [],\n",
" \"chosen_mask\": []\n",
"\n",
" }\n",
"\n",
" # Determine the longest sequence to set a common padding length\n",
" max_length_common = 0\n",
" if batch:\n",
" for key in [\"chosen\", \"rejected\"]:\n",
" current_max = max(len(item[key])+1 for item in batch)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"Device:\", device)\n",
"\n",
"customized_collate_fn = partial(\n",
" custom_collate_fn,\n",
" device=device, # Put the data directly on a GPU if available\n",
" mask_prompt_tokens=True, # This is optional\n",
" allowed_max_length=1024 # The supported context length of the model\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5d29e996-e267-4348-bc1d-4ac6b725cf6a",
"metadata": {
"id": "5d29e996-e267-4348-bc1d-4ac6b725cf6a"
},
"source": [
"- Now, let's see the `customized_collate_fn` in action and apply it to some sample data from our preference dataset; for this, we take the first two entries:"
"{'instruction': 'Evaluate the following phrase by transforming it into the '\n",
" 'spelling given.',\n",
" 'input': 'freind --> friend',\n",
" 'output': 'The spelling of the given phrase \"freind\" is incorrect, the '\n",
" 'correct spelling is \"friend\".',\n",
" 'rejected': 'The spelling of the given phrase \"freind\" is flat out wrong, get '\n",
" 'it together, the correct spelling is \"friend\".',\n",
" 'chosen': 'The spelling of the given phrase \"freind\" is incorrect, the '\n",
" 'correct spelling is \"friend\".'}\n",
"\n",
"{'instruction': 'Edit the following sentence for grammar.',\n",
" 'input': 'He go to the park every day.',\n",
" 'output': 'He goes to the park every day.',\n",
" 'rejected': 'He goes to the stupid park every single day.',\n",
" 'chosen': 'He goes to the park every day.'}\n"
]
}
],
"source": [
"example_data = data[:2]\n",
"\n",
"for i in example_data:\n",
" print()\n",
" pprint.pp(i)"
]
},
{
"cell_type": "markdown",
"id": "8f1436cc-fbe5-4581-89d8-1992b5f04042",
"metadata": {
"id": "8f1436cc-fbe5-4581-89d8-1992b5f04042"
},
"source": [
"- Next, let's instantiate an `example_dataset` and use a PyTorch `DataLoader` to create an `example_dataloader` that mimics the data loader we will use for the model training later:"
"- The prompts are a list of tensors, where each tensor contains the token IDs for a given example; since we selected a batch size of 2, we have two lists of token ID tensors here:"
"- We don't really need the responses for training; what we need to feed to the model during training are the `\"chosen\"` and `\"rejected\"` entries\n",
"- The `\"chosen\"` and `\"rejected\"` response entries are padded so that we can stack them as tensors; similar to the prompts, these response texts are encoded into token IDs:"
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Evaluate the following phrase by transforming it into the spelling given.\n",
"\n",
"### Input:\n",
"freind --> friend\n",
"\n",
"### Response:\n",
"The spelling of the given phrase \"freind\" is incorrect, the correct spelling is \"friend\".<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>\n"
]
}
],
"source": [
"text = decode_tokens_from_batch(\n",
" token_ids=batch[\"chosen\"][0],\n",
" tokenizer=tokenizer,\n",
")\n",
"print(text)"
]
},
{
"cell_type": "markdown",
"id": "ac9fbdbd-1cff-401f-8e6c-cd98c134c0f2",
"metadata": {
"id": "ac9fbdbd-1cff-401f-8e6c-cd98c134c0f2"
},
"source": [
"- As we can see above, similar to instruction finetuning, the response that is passed to the model during training also contains the input prompt\n",
"- Also note that we included `<|endoftext|>` tokens as padding tokens, which are necessary so that we can extend the responses to a similar length to stack them as a batch\n",
"- Don't worry; the `<|endoftext|>` tokens will be ignored in the loss later so that they won't affect the training outcome\n",
"- Let's now also inspect the corresponding rejected response:"
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Evaluate the following phrase by transforming it into the spelling given.\n",
"\n",
"### Input:\n",
"freind --> friend\n",
"\n",
"### Response:\n",
"The spelling of the given phrase \"freind\" is flat out wrong, get it together, the correct spelling is \"friend\".<|endoftext|>\n"
]
}
],
"source": [
"text = decode_tokens_from_batch(\n",
" token_ids=batch[\"rejected\"][0],\n",
" tokenizer=tokenizer,\n",
")\n",
"print(text)"
]
},
{
"cell_type": "markdown",
"id": "715dc968-aa64-4388-b577-7c295831bdcf",
"metadata": {
"id": "715dc968-aa64-4388-b577-7c295831bdcf"
},
"source": [
"- In this case, as we can see above, the rejected response is a more impolite version of the chosen response (we don't want the model to generate impolite responses)\n",
"- Lastly, let's talk about the data masks: if you took a closer look at our custom collate function we implemented above, we created a `\"chosen_mask\"` and a `\"rejected_mask\"` for each dataset entry\n",
"- The masks have the same shape as the response entries, as shown below for the `\"chosen\"` entry:"
"- The `True` values denote token IDs that correspond to the actual response\n",
"- the `False` tokens correspond to token IDs that correspond to either prompt tokens (if we set `mask_prompt_tokens=True` in the `customized_collate_fn` function, which we previously did) or padding tokens\n",
"- Hence, we can use the mask as a selection mask to select only the token IDs that correspond to the response, that is, stripping all prompt and padding tokens, as we can see below:"
"- Each row shows the shape of the `\"chosen\"` and `\"rejected\"` entries in each batch\n",
"- Since we applied padding on a batch-by-batch basis, each row has a different shape\n",
"- This is for efficiency reasons because it would be inefficient to pad all samples to the longest sample in the whole dataset"
]
},
{
"cell_type": "markdown",
"id": "29cb0543-1142-4374-8825-3384e20c6ac0",
"metadata": {
"id": "29cb0543-1142-4374-8825-3384e20c6ac0"
},
"source": [
" \n",
"# 3) Loading a finetuned LLM for DPO alignment"
]
},
{
"cell_type": "markdown",
"id": "22b08881-b769-4b26-8153-5ec0e8573ed2",
"metadata": {
"id": "22b08881-b769-4b26-8153-5ec0e8573ed2"
},
"source": [
"- LLM alignment steps, such as RLHF or DPO, assume that we already have an instruction-finetuned model\n",
"- This section contains minimal code to load the model that was instruction finetuned and saved in chapter 7 (via [../01_main-chapter-code/ch07.ipynb](../01_main-chapter-code/ch07.ipynb))\n",
"- Make sure you run the chapter 7 code first to create the instruction-finetuned model before you proceed\n",
"- The code below will copy the instruction-finetuned model into the current directory:"
"- Before training the loaded model with DPO, let's make sure that the finetuned model was saved and loaded correctly by trying it out on some sample data:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "4357aec5-0db2-4d73-b37b-539cd8fa80a3",
"metadata": {
"id": "4357aec5-0db2-4d73-b37b-539cd8fa80a3"
},
"outputs": [],
"source": [
"prompt = \"\"\"Below is an instruction that describes a task. Write a response\n",
"that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Convert the active sentence to passive: 'The chef cooks the meal every day.'\n",
"- As we can see above, the model gives a reasonable and correct response\n",
"- As explained in chapter 7, in practice, we would clean up the response to only return the response text with the prompt and prompt style removed (similar to what you are familiar with from ChatGPT, for example):"
"- Now, we are almost ready to get to the DPO part\n",
"- As mentioned at the beginning of this notebook, DPO works with two LLMs: a policy model (the LLM that we want to optimize) and a reference model (the original model that we keep unchanged)\n",
"- Below, we rename the `model` as `policy_model` and instantiate a second instance of the model we refer to as the `reference_model`"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "5d88cc3a-312e-4b29-bc6d-de8354c1eb9f",
"metadata": {
"id": "5d88cc3a-312e-4b29-bc6d-de8354c1eb9f"
},
"outputs": [],
"source": [
"policy_model = model\n",
"\n",
"reference_model = GPTModel(BASE_CONFIG)\n",
"reference_model.load_state_dict(\n",
" torch.load(\n",
" \"gpt2-medium355M-sft.pth\",\n",
" map_location=torch.device(\"cpu\"),\n",
" weights_only=True\n",
" )\n",
")\n",
"reference_model.eval()\n",
"\n",
"policy_model.to(device)\n",
"reference_model.to(device);"
]
},
{
"cell_type": "markdown",
"id": "9c6c1469-0038-4914-8aa5-15b1f81877cc",
"metadata": {
"id": "9c6c1469-0038-4914-8aa5-15b1f81877cc"
},
"source": [
" \n",
"# 4) Coding the DPO Loss Function"
]
},
{
"cell_type": "markdown",
"id": "75dbe60c-e4ce-413e-beec-22eff0237d11",
"metadata": {
"id": "75dbe60c-e4ce-413e-beec-22eff0237d11"
},
"source": [
"- After we took care of the model loading and dataset preparation in the previous sections, we can now get to the fun part and code the DPO loss\n",
"- Note that the DPO loss code below is based on the method proposed in the [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) paper\n",
"- For reference, the core DPO equation is shown again below:\n",
" - \"expected value\" $\\mathbb{E}$ is statistics jargon and stands for the average or mean value of the random variable (the expression inside the brackets)\n",
" - The $\\pi_{\\theta}$ variable is the so-called policy (a term borrowed from reinforcement learning) and represents the LLM we want to optimize; $\\pi_{ref}$ is a reference LLM, which is typically the original LLM before optimization (at the beginning of the training, $\\pi_{\\theta}$ and $\\pi_{ref}$ are typically the same)\n",
" - $\\beta$ is a hyperparameter to control the divergence between the $\\pi_{\\theta}$ and the reference model; increasing $\\beta$ increases the impact of the difference between\n",
"$\\pi_{\\theta}$ and $\\pi_{ref}$ in terms of their log probabilities on the overall loss function, thereby increasing the divergence between the two models\n",
"- In code, we can implement the DPO loss as follows:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "38CsrrwJIZiV",
"metadata": {
"id": "38CsrrwJIZiV"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"def compute_dpo_loss(\n",
" model_chosen_logprobs,\n",
" model_rejected_logprobs,\n",
" reference_chosen_logprobs,\n",
" reference_rejected_logprobs,\n",
" beta=0.1,\n",
" ):\n",
" \"\"\"Compute the DPO loss for a batch of policy and reference model log probabilities.\n",
"\n",
" Args:\n",
" policy_chosen_logprobs: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)\n",
" policy_rejected_logprobs: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)\n",
" reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)\n",
" reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)\n",
" beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.\n",
" label_smoothing: conservativeness for DPO loss.\n",
"\n",
" Returns:\n",
" A tuple of three tensors: (loss, chosen_rewards, rejected_rewards).\n",
"- If you are familiar with logarithms, note that we have the general relationship $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$, which we applied in the code above\n",
"- Keeping this in mind, let's go through some of the steps (we will calculate the `logprobs` using a separate function later)\n",
"- These lines above calculate the difference in log probabilities (logits) for the chosen and rejected samples for both the policy model and the reference model (this is due to $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$):\n",
"- Next, the code `logits = model_logratios - reference_logratios` computes the difference between the model's log ratios and the reference model's log ratios, i.e., \n",
"- Finally, `losses = -F.logsigmoid(beta * logits)` calculates the loss using the log-sigmoid function; in the original equation, the term inside the expectation is \n",
"- Above, we assumed that the log probabilities were already computed; let's now define a `compute_logprobs` function that we can use to compute these log probabilities that were passed into the `compute_dpo_loss` function above, that is, the values $\\pi_\\theta (y_w \\mid x)$, ${\\pi_\\theta (y_l \\mid x)}$, and so forth:"
"- Note that this function above might look a bit intimidating at first due to the `torch.gather` function, but it's pretty similar to what happens under the hood in PyTorch's `cross_entropy` function\n",
"- So, above, we can see that the two implementations are equivalent, but let's narrow down a bit further to the `torch.gather` mechanics\n",
"- Consider the following two tensors:"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "508db6ba-cc40-479f-a996-2250cf862388",
"metadata": {
"id": "508db6ba-cc40-479f-a996-2250cf862388"
},
"outputs": [],
"source": [
"t = torch.tensor(\n",
" [[1., 2.,],\n",
" [3., 4.]]\n",
")\n",
"\n",
"m = torch.tensor(\n",
" [[1, 1],\n",
" [0, 1]]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "821cbf45-8fbb-47b7-bae8-6c3271e36979",
"metadata": {
"id": "821cbf45-8fbb-47b7-bae8-6c3271e36979"
},
"source": [
"- Above, `t` is a tensor we want to select from, and `m` is a mask to specify how we want to select\n",
" - For instance, since `m` contains `[1, 1]` n the first row, it will select two times the value of `t` in index position `1`, which is the value 2.\n",
" - The second row of `m`, `[0, 1]`, selects index positions 0 and 1 in the second row or `t`, which are `3.` and `4.`"
"- In other words, `torch.gather` is a selection function\n",
"- When we computed the loss earlier, we used it to retrieve the log probabilities corresponding to the correct token in the 50,256-token vocabulary\n",
"- The \"correct\" tokens are the tokens given in the response entry"
]
},
{
"cell_type": "markdown",
"id": "d5d10a43-ee5b-47ed-9d55-ddd96e66cf0b",
"metadata": {
"id": "d5d10a43-ee5b-47ed-9d55-ddd96e66cf0b"
},
"source": [
"- Regarding the `compute_logprobs` function above, we use `torch.gather` here because it gives us a bit more control than `cross_entropy`, but is, in essence, a similar idea\n",
"- The `selection_mask` we use there is to optionally ignore prompt and padding tokens\n",
"- We can then use the `compute_logprobs` function as follows to compute the inputs for the `compute_dpo_loss` loss function"
"- Why a specified `num_batches`? That's purely for efficiency reasons (because calculating the loss on the whole dataset each time would slow down the training significantly"
]
},
{
"cell_type": "markdown",
"id": "2cca95b7-18fe-4076-9138-f70f21607b8c",
"metadata": {
"id": "2cca95b7-18fe-4076-9138-f70f21607b8c"
},
"source": [
"- Lastly, we define a convenience function for our training function later; this `evaluate_dpo_loss_loader` function computes the DPO loss and rewards for both the training and validation loader for logging purposes:"
"- In this section, we covered a lot of ground as a brief recap:\n",
" - The flow is: compute `logits` via the models $\\rightarrow$ `compute_logprobs` from logits $\\rightarrow$ compute `compute_dpo_loss` from log probabilities\n",
" - we have the `compute_dpo_loss_batch` function that facilitates the process above\n",
" - the `compute_dpo_loss_loader` utility function applies the `compute_dpo_loss_batch` function to a data loader\n",
" - the `evaluate_dpo_loss_loader` function applies the `compute_dpo_loss_batch` to both the training and validation set data loaders for logging purposes"
]
},
{
"cell_type": "markdown",
"id": "cb8a8f18-536e-4d83-a0d0-ac518a85f157",
"metadata": {
"id": "cb8a8f18-536e-4d83-a0d0-ac518a85f157"
},
"source": [
" \n",
"# 5) Training the model"
]
},
{
"cell_type": "markdown",
"id": "4b11d63d-3ddc-4070-9b2b-5ca0edb08d0c",
"metadata": {
"id": "4b11d63d-3ddc-4070-9b2b-5ca0edb08d0c"
},
"source": [
"- After setting up the DPO loss functions in the previous section, we can now finally train the model\n",
"- Note that this training function is the same one we used for pretraining and instruction finetuning, with minor differences:\n",
" - we swap the cross entropy loss with our new DPO loss function\n",
" - we also track the rewards and reward margins, which are commonly used in RLHF and DPO contexts to track the training progress\n"
"- Note that the goal of DPO is to induce slight style changes; this means we want the model to generate similar but slightly more polite responses\n",
"- Before we execute the following code cell that starts the training, here are a few notes about some of the settings:\n",
" - we are only passing the parameters of the policy model into the `AdamW` optimizer; that's the model we want to optimize (we don't want to modify the reference model)\n",
" - we only train for 1 epoch; that's because DPO is very prone to collapse (the loss might improve, but the model will start generating nonsensical texts)\n",
" - in DPO, it's best to use a very small learning rate\n",
" - the beta value can be increased from 0.1 to 0.5 to reduce the effect of DPO (we use 0.1 here to make the results more noticeable)\n",
" - The training takes about 2 minutes on an A100 GPU, but it can also be trained in 4 minutes on a smaller L4 GPU; training on a M3 MacBook Air takes about 30 minutes"
"Ep 1 (Step 000000): Train loss 0.692, Val loss 0.693, Train reward margins 0.019, Val reward margins 0.009\n",
"Ep 1 (Step 000005): Train loss 0.690, Val loss 0.691, Train reward margins 0.070, Val reward margins 0.052\n",
"Ep 1 (Step 000010): Train loss 0.687, Val loss 0.688, Train reward margins 0.126, Val reward margins 0.108\n",
"Ep 1 (Step 000015): Train loss 0.676, Val loss 0.685, Train reward margins 0.362, Val reward margins 0.173\n",
"Ep 1 (Step 000020): Train loss 0.676, Val loss 0.680, Train reward margins 0.351, Val reward margins 0.264\n",
"Ep 1 (Step 000025): Train loss 0.666, Val loss 0.676, Train reward margins 0.564, Val reward margins 0.359\n",
"Ep 1 (Step 000030): Train loss 0.672, Val loss 0.672, Train reward margins 0.456, Val reward margins 0.441\n",
"Ep 1 (Step 000035): Train loss 0.663, Val loss 0.669, Train reward margins 0.658, Val reward margins 0.511\n",
"Ep 1 (Step 000040): Train loss 0.666, Val loss 0.666, Train reward margins 0.597, Val reward margins 0.574\n",
"Ep 1 (Step 000045): Train loss 0.648, Val loss 0.662, Train reward margins 0.982, Val reward margins 0.660\n",
"Ep 1 (Step 000050): Train loss 0.648, Val loss 0.659, Train reward margins 0.993, Val reward margins 0.734\n",
"Ep 1 (Step 000055): Train loss 0.647, Val loss 0.656, Train reward margins 1.014, Val reward margins 0.799\n",
"Ep 1 (Step 000060): Train loss 0.652, Val loss 0.653, Train reward margins 0.893, Val reward margins 0.870\n",
"Ep 1 (Step 000065): Train loss 0.631, Val loss 0.650, Train reward margins 1.361, Val reward margins 0.948\n",
"Ep 1 (Step 000070): Train loss 0.618, Val loss 0.646, Train reward margins 1.699, Val reward margins 1.038\n",
"Ep 1 (Step 000075): Train loss 0.617, Val loss 0.642, Train reward margins 1.733, Val reward margins 1.121\n",
"Ep 1 (Step 000080): Train loss 0.592, Val loss 0.639, Train reward margins 2.333, Val reward margins 1.194\n",
"Ep 1 (Step 000085): Train loss 0.610, Val loss 0.636, Train reward margins 1.907, Val reward margins 1.275\n",
"Ep 1 (Step 000090): Train loss 0.650, Val loss 0.633, Train reward margins 0.964, Val reward margins 1.353\n",
"Ep 1 (Step 000095): Train loss 0.607, Val loss 0.630, Train reward margins 1.962, Val reward margins 1.423\n",
"Ep 1 (Step 000100): Train loss 0.600, Val loss 0.627, Train reward margins 2.127, Val reward margins 1.500\n",
"Ep 1 (Step 000105): Train loss 0.590, Val loss 0.624, Train reward margins 2.458, Val reward margins 1.564\n",
"Ep 1 (Step 000110): Train loss 0.607, Val loss 0.622, Train reward margins 1.976, Val reward margins 1.621\n",
"Ep 1 (Step 000115): Train loss 0.621, Val loss 0.620, Train reward margins 1.605, Val reward margins 1.682\n",
"Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Rewrite the sentence using a metaphor. ### Input: The book is very interesting. ### Response: The book is a treat.<|endoftext|>The following is an instruction that describes a task. Write a response that appropriately completes the request. ### Input: The assignment was written by the student. ### Response\n",
"- As we can see above, the loss continues to improve, which is a good sign\n",
"- Based on the downward slope, one might be tempted to train the model a bit further (and readers are encouraged to try this), but not that DPO is prone to collapse, where the model may start generating nonsensical responses\n",
"- Next, let's take a look at the reward margins:"
"- As we can see based on the reference model and policy model responses above, the optimized model (i.e., the policy model) indeed slightly changed its style compared to the original model (i.e., reference model)\n",
"- For instance, `\"Dance\" can be classified as a verb.` changed to `The input string \"Dance\" could be classified as a verb.` which is a slightly more polite response (the use of \"could\" instead of \"can\" makes the statement sound less assertive and more tentative)"