@@ -6,7 +6,7 @@
 "provenance": [],
 "machine_shape": "hm",
 "gpuType": "A100",
-"authorship_tag": "ABX9TyOJJCuqxZQnS1q+Fvz5+URG",
+"authorship_tag": "ABX9TyNuIN7/ICiXCX5xELzN1Y3R",
 "include_colab_link": true
 },
 "kernelspec": {

@@ -380,6 +380,8 @@
 "source": [
 "# Fine-tune a Mistral-7b model with DPO\n",
 "\n",
+"> 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)\n",
+"\n",
 "❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne)."
 ],
 "metadata": {

@@ -469,10 +471,10 @@
 " prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n",
 "\n",
 " # Format chosen answer\n",
-" chosen = example['chatgpt'] + \"<|im_end|>\\n\"\n",
+" chosen = example['chosen'] + \"<|im_end|>\\n\"\n",
 "\n",
 " # Format rejected answer\n",
-" rejected = example['llama2-13b-chat'] + \"<|im_end|>\\n\"\n",
+" rejected = example['rejected'] + \"<|im_end|>\\n\"\n",
 "\n",
 " return {\n",
 " \"prompt\": system + prompt,\n",

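The key renames above suggest the notebook now loads a preference dataset that already exposes the conventional `chosen`/`rejected` columns instead of per-model answer columns. A minimal sketch of how such a dataset would be run through this formatting function, assuming it is named `chatml_format` (the function name and dataset name are assumptions, not shown in the diff):

```python
# A minimal sketch, assuming a preference dataset with "system", "question",
# "chosen", and "rejected" columns; the dataset name below is illustrative.
from datasets import load_dataset

dataset = load_dataset("Intel/orca_dpo_pairs")["train"]

# Map every row through the formatting function and drop the raw columns,
# keeping only the "prompt", "chosen", and "rejected" fields DPOTrainer expects.
dataset = dataset.map(chatml_format, remove_columns=dataset.column_names)
```
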
@@ -561,13 +563,6 @@
 ")\n",
 "model.config.use_cache = False\n",
 "\n",
-"# Reference model\n",
-"ref_model = AutoModelForCausalLM.from_pretrained(\n",
-" model_name,\n",
-" torch_dtype=torch.float16,\n",
-" load_in_4bit=True\n",
-")\n",
-"\n",
 "# Training arguments\n",
 "training_args = TrainingArguments(\n",
 " per_device_train_batch_size=4,\n",

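Dropping the explicit reference model means only one 4-bit copy of the weights stays on the GPU. As a side note, the removed block used the bare `load_in_4bit=True` flag; newer `transformers` releases prefer routing quantization through `BitsAndBytesConfig`. A minimal sketch of the policy-model load that remains, under that assumption (`model_name` as defined earlier in the notebook):

```python
# A minimal sketch, assuming a recent transformers + bitsandbytes stack where
# 4-bit loading is configured through BitsAndBytesConfig.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,                      # defined earlier in the notebook
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)
model.config.use_cache = False       # the KV cache is not needed during training
```
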
@@ -588,7 +583,6 @@
 "# Create DPO trainer\n",
 "dpo_trainer = DPOTrainer(\n",
 " model,\n",
-" ref_model,\n",
 " args=training_args,\n",
 " train_dataset=dataset,\n",
 " tokenizer=tokenizer,\n",

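Passing no `ref_model` is safe here because TRL's `DPOTrainer` builds the reference policy itself when the argument is omitted: it clones and freezes the policy model, or, if a `peft_config` is supplied, computes reference log-probabilities with the adapter disabled on the same base weights, so no second 7B copy sits in GPU memory. A minimal sketch under the PEFT assumption (LoRA hyperparameters are illustrative, not taken from the diff):

```python
# A minimal sketch, assuming trl's DPOTrainer with a LoRA adapter; leaving
# ref_model unset (None) makes the trainer score the reference policy by
# disabling the adapter, instead of holding a second full model in memory.
from peft import LoraConfig
from trl import DPOTrainer

peft_config = LoraConfig(   # illustrative hyperparameters
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

dpo_trainer = DPOTrainer(
    model,                  # policy model; ref_model defaults to None
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,               # weight of the implicit KL penalty in the DPO loss
)
```
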
@@ -624,7 +618,7 @@
 "tokenizer.save_pretrained(\"final_checkpoint\")\n",
 "\n",
 "# Flush memory\n",
-"del dpo_trainer, model, ref_model\n",
+"del dpo_trainer, model\n",
 "gc.collect()\n",
 "torch.cuda.empty_cache()\n",
 "\n",

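With no reference model created, only the trainer and policy model need deleting before GPU memory is reclaimed. A minimal sketch of the flush-and-reload pattern this cell belongs to; the fp16 reload and adapter merge are the usual QLoRA follow-up and an assumption about the notebook's next step, not shown in the diff:

```python
# A minimal sketch: release GPU memory, then reload the base model in fp16
# and fold the saved LoRA adapter into it (checkpoint path as saved above).
import gc
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

del dpo_trainer, model            # drop the last Python references
gc.collect()                      # let Python free the underlying tensors
torch.cuda.empty_cache()          # return cached CUDA blocks to the allocator

base_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(base_model, "final_checkpoint")
model = model.merge_and_unload()  # merge LoRA weights into the base model
```
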