improve latex rendering in dpo notebook

rasbt 2024-08-04 09:19:54 -05:00
parent 2c6cdb497f
commit f302f5e8d5


@ -1886,8 +1886,9 @@
"reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs\n", "reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs\n",
"```\n", "```\n",
"\n", "\n",
"- These lines above calculate the difference in log probabilities (logits) for the chosen and rejected samples for both the policy model and the reference model:\n", "- These lines above calculate the difference in log probabilities (logits) for the chosen and rejected samples for both the policy model and the reference model (this is due to $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$):\n",
"$\\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_\\theta (y_l \\mid x)} \\right) \\quad \\text{and} \\quad \\log \\left( \\frac{\\pi_{\\text{ref}}(y_w \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)} \\right)$; this is due to $\\log\\left(\\frac{a}{b}\\right) = \\log a - \\log b$" "\n",
"$$\\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_\\theta (y_l \\mid x)} \\right) \\quad \\text{and} \\quad \\log \\left( \\frac{\\pi_{\\text{ref}}(y_w \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)} \\right)$$"
] ]
}, },
{ {
@ -1897,8 +1898,10 @@
"id": "5458d217-e0ad-40a5-925c-507a8fcf5795" "id": "5458d217-e0ad-40a5-925c-507a8fcf5795"
}, },
"source": [ "source": [
"- Next, the code `logits = model_logratios - reference_logratios` computes the difference between the model's log ratios and the reference model's log ratios, i.e., $\\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n", "- Next, the code `logits = model_logratios - reference_logratios` computes the difference between the model's log ratios and the reference model's log ratios, i.e., \n",
"- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right)$\n" "\n",
"$$\\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n",
"- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right)$$\n"
] ]
}, },
{ {
@ -1908,8 +1911,10 @@
"id": "f18e3e36-f5f1-407f-b662-4c20a0ac0354" "id": "f18e3e36-f5f1-407f-b662-4c20a0ac0354"
}, },
"source": [ "source": [
"- Finally, `losses = -F.logsigmoid(beta * logits)` calculates the loss using the log-sigmoid function; in the original equation, the term inside the expectation is $\\log \\sigma \\left( \\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n", "- Finally, `losses = -F.logsigmoid(beta * logits)` calculates the loss using the log-sigmoid function; in the original equation, the term inside the expectation is \n",
"- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right) \\right)$" "\n",
"$$\\log \\sigma \\left( \\beta \\log \\left( \\frac{\\pi_\\theta (y_w \\mid x)}{\\pi_{\\text{ref}} (y_w \\mid x)} \\right)\n",
"- \\beta \\log \\left( \\frac{\\pi_\\theta (y_l \\mid x)}{\\pi_{\\text{ref}} (y_l \\mid x)} \\right) \\right)$$"
] ]
}, },
{ {
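Reviewer note: the three notebook cells touched in the hunks above describe one computation, which can be sketched in plain Python for a single preference pair. This is only an illustrative sketch and not the notebook's implementation: the notebook uses PyTorch tensors and `F.logsigmoid`, whereas the scalar `log_sigmoid` helper, the `dpo_loss` function name, the `policy_*` argument names, and the default `beta=0.1` are assumptions introduced here.

```python
import math


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)); hypothetical scalar stand-in
    # for the F.logsigmoid call used in the notebook.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen_logprob: float,
    policy_rejected_logprob: float,
    reference_chosen_logprob: float,
    reference_rejected_logprob: float,
    beta: float = 0.1,  # assumed default, for illustration only
) -> float:
    # log(a / b) = log a - log b, so each log ratio is a plain difference
    model_logratios = policy_chosen_logprob - policy_rejected_logprob
    reference_logratios = reference_chosen_logprob - reference_rejected_logprob

    # Difference between the policy's and the reference model's log ratios
    logits = model_logratios - reference_logratios

    # Negative log-sigmoid of the scaled difference: the per-pair term
    # inside the expectation, with the sign flipped to give a loss
    return -log_sigmoid(beta * logits)


# Policy identical to the reference: logits are 0, so the loss is log(2)
baseline = dpo_loss(-1.0, -2.0, -1.0, -2.0)

# Policy prefers the chosen response more strongly than the reference does
improved = dpo_loss(-0.5, -3.0, -1.0, -2.0)
```

With these inputs, `improved` is smaller than `baseline`: the loss decreases as the policy widens the chosen-versus-rejected margin relative to the reference model.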
@ -3089,7 +3094,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.4" "version": "3.10.6"
} }
}, },
"nbformat": 4, "nbformat": 4,