|
1876 | 1876 | " reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)\n",
|
1877 | 1877 | " reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)\n",
|
1878 | 1878 | " beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.\n",
|
1879 |
| - " label_smoothing: conservativeness for DPO loss.\n", |
1880 | 1879 | "\n",
|
1881 | 1880 | " Returns:\n",
|
1882 | 1881 | " A tuple of three tensors: (loss, chosen_rewards, rejected_rewards).\n",
|
|
1998 | 1997 | " selected_log_probs = selected_log_probs * mask\n",
|
1999 | 1998 | "\n",
|
2000 | 1999 | " # Calculate the average log probability excluding padding tokens\n",
|
2001 |
| - " # This averages over the tokens, so the shape is (batch_size, num_tokens)\n", |
| 2000 | + " # This averages over the tokens, so the shape is (batch_size,)\n", |
2002 | 2001 | " avg_log_prob = selected_log_probs.sum(-1) / mask.sum(-1)\n",
|
2003 | 2002 | "\n",
|
2004 | 2003 | " return avg_log_prob\n",
|
|
2439 | 2438 | " for epoch in range(num_epochs):\n",
|
2440 | 2439 | " policy_model.train() # Set model to training mode\n",
|
2441 | 2440 | "\n",
|
2442 |
| - " for batch_idx, batch in enumerate(train_loader):\n", |
| 2441 | + " for batch in train_loader:\n", |
2443 | 2442 | "\n",
|
2444 | 2443 | " optimizer.zero_grad() # Reset loss gradients from previous batch iteration\n",
|
2445 | 2444 | "\n",
|
|
3113 | 3112 | "provenance": []
|
3114 | 3113 | },
|
3115 | 3114 | "kernelspec": {
|
3116 |
| - "display_name": "Python 3 (ipykernel)", |
| 3115 | + "display_name": ".venv", |
3117 | 3116 | "language": "python",
|
3118 | 3117 | "name": "python3"
|
3119 | 3118 | },
|
|
3127 | 3126 | "name": "python",
|
3128 | 3127 | "nbconvert_exporter": "python",
|
3129 | 3128 | "pygments_lexer": "ipython3",
|
3130 |
| - "version": "3.10.16" |
| 3129 | + "version": "3.12.6" |
3131 | 3130 | }
|
3132 | 3131 | },
|
3133 | 3132 | "nbformat": 4,
|
|
0 commit comments