Skip to content

Commit 1b242d0

Browse files
authored
Minor DPO fixes (#617)
* minor dpo fixes * Update dpo-from-scratch.ipynb metadata diff
1 parent f3d1566 commit 1b242d0

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb

+4-5
Original file line numberDiff line numberDiff line change
@@ -1876,7 +1876,6 @@
18761876
" reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)\n",
18771877
" reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)\n",
18781878
" beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.\n",
1879-
" label_smoothing: conservativeness for DPO loss.\n",
18801879
"\n",
18811880
" Returns:\n",
18821881
" A tuple of three tensors: (loss, chosen_rewards, rejected_rewards).\n",
@@ -1998,7 +1997,7 @@
19981997
" selected_log_probs = selected_log_probs * mask\n",
19991998
"\n",
20001999
" # Calculate the average log probability excluding padding tokens\n",
2001-
" # This averages over the tokens, so the shape is (batch_size, num_tokens)\n",
2000+
" # This averages over the tokens, so the shape is (batch_size,)\n",
20022001
" avg_log_prob = selected_log_probs.sum(-1) / mask.sum(-1)\n",
20032002
"\n",
20042003
" return avg_log_prob\n",
@@ -2439,7 +2438,7 @@
24392438
" for epoch in range(num_epochs):\n",
24402439
" policy_model.train() # Set model to training mode\n",
24412440
"\n",
2442-
" for batch_idx, batch in enumerate(train_loader):\n",
2441+
" for batch in train_loader:\n",
24432442
"\n",
24442443
" optimizer.zero_grad() # Reset loss gradients from previous batch iteration\n",
24452444
"\n",
@@ -3113,7 +3112,7 @@
31133112
"provenance": []
31143113
},
31153114
"kernelspec": {
3116-
"display_name": "Python 3 (ipykernel)",
3115+
"display_name": ".venv",
31173116
"language": "python",
31183117
"name": "python3"
31193118
},
@@ -3127,7 +3126,7 @@
31273126
"name": "python",
31283127
"nbconvert_exporter": "python",
31293128
"pygments_lexer": "ipython3",
3130-
"version": "3.10.16"
3129+
"version": "3.12.6"
31313130
}
31323131
},
31333132
"nbformat": 4,

0 commit comments

Comments
 (0)