Minor DPO fixes (#617)

* minor dpo fixes

* Update dpo-from-scratch.ipynb

metadata diff
This commit is contained in:
casinca 2025-04-16 19:56:49 +02:00 committed by GitHub
parent f3d1566c2e
commit 1b242d01a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1876,7 +1876,6 @@
" reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)\n",
" reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)\n",
" beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.\n",
" label_smoothing: conservativeness for DPO loss.\n",
"\n",
" Returns:\n",
" A tuple of three tensors: (loss, chosen_rewards, rejected_rewards).\n",
@@ -1998,7 +1997,7 @@
" selected_log_probs = selected_log_probs * mask\n",
"\n",
" # Calculate the average log probability excluding padding tokens\n",
" # This averages over the tokens, so the shape is (batch_size, num_tokens)\n",
" # This averages over the tokens, so the shape is (batch_size,)\n",
" avg_log_prob = selected_log_probs.sum(-1) / mask.sum(-1)\n",
"\n",
" return avg_log_prob\n",
@@ -2439,7 +2438,7 @@
" for epoch in range(num_epochs):\n",
" policy_model.train() # Set model to training mode\n",
"\n",
" for batch_idx, batch in enumerate(train_loader):\n",
" for batch in train_loader:\n",
"\n",
" optimizer.zero_grad() # Reset loss gradients from previous batch iteration\n",
"\n",
@@ -3113,7 +3112,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -3127,7 +3126,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.12.6"
}
},
"nbformat": 4,