improve latex rendering in dpo notebook

This commit is contained in:
rasbt 2024-08-04 09:19:54 -05:00
parent 2c6cdb497f
commit f302f5e8d5

View File

@ -1886,8 +1886,9 @@
"reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs\n",
"```\n",
"\n",
"- These lines above calculate the difference in log probabilities (logits) for the chosen and rejected samples for both the policy model and the reference model:\n",
"$\\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_\\theta (y_l \\mid x)} \\right) \\quad \\text{and} \\quad \\log \\left( \\frac{\\pi_{\\text{ref}}(y_w \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)} \\right)$; this is due to $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$"
"- These lines above calculate the difference in log probabilities (logits) for the chosen and rejected samples for both the policy model and the reference model (this is due to $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$):\n",
"\n",
"$$\\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_\\theta (y_l \\mid x)} \\right) \\quad \\text{and} \\quad \\log \\left( \\frac{\\pi_{\\text{ref}}(y_w \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)} \\right)$$"
]
},
{
@ -1897,8 +1898,10 @@
"id": "5458d217-e0ad-40a5-925c-507a8fcf5795"
},
"source": [
"- Next, the code `logits = model_logratios - reference_logratios` computes the difference between the model's log ratios and the reference model's log ratios, i.e., $\\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n",
"- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right)$\n"
"- Next, the code `logits = model_logratios - reference_logratios` computes the difference between the model's log ratios and the reference model's log ratios, i.e., \n",
"\n",
"$$\\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n",
"- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right)$$\n"
]
},
{
@ -1908,8 +1911,10 @@
"id": "f18e3e36-f5f1-407f-b662-4c20a0ac0354"
},
"source": [
"- Finally, `losses = -F.logsigmoid(beta * logits)` calculates the loss using the log-sigmoid function; in the original equation, the term inside the expectation is $\\log \\sigma \\left( \\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n",
"- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right) \\right)$"
"- Finally, `losses = -F.logsigmoid(beta * logits)` calculates the loss using the log-sigmoid function; in the original equation, the term inside the expectation is \n",
"\n",
"$$\\log \\sigma \\left( \\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n",
"- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right) \\right)$$"
]
},
{
@ -3089,7 +3094,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,