Skip to content

Commit 4dc551d

Browse files
committed
Created using Colaboratory
1 parent 87306ca commit 4dc551d

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

Fine_tune_a_Mistral_7b_model_with_DPO.ipynb

+6-12
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"provenance": [],
77
"machine_shape": "hm",
88
"gpuType": "A100",
9-
"authorship_tag": "ABX9TyOJJCuqxZQnS1q+Fvz5+URG",
9+
"authorship_tag": "ABX9TyNuIN7/ICiXCX5xELzN1Y3R",
1010
"include_colab_link": true
1111
},
1212
"kernelspec": {
@@ -380,6 +380,8 @@
380380
"source": [
381381
"# Fine-tune a Mistral-7b model with DPO\n",
382382
"\n",
383+
"> 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)\n",
384+
"\n",
383385
"❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne)."
384386
],
385387
"metadata": {
@@ -469,10 +471,10 @@
469471
" prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n",
470472
"\n",
471473
" # Format chosen answer\n",
472-
" chosen = example['chatgpt'] + \"<|im_end|>\\n\"\n",
474+
" chosen = example['chosen'] + \"<|im_end|>\\n\"\n",
473475
"\n",
474476
" # Format rejected answer\n",
475-
" rejected = example['llama2-13b-chat'] + \"<|im_end|>\\n\"\n",
477+
" rejected = example['rejected'] + \"<|im_end|>\\n\"\n",
476478
"\n",
477479
" return {\n",
478480
" \"prompt\": system + prompt,\n",
@@ -561,13 +563,6 @@
561563
")\n",
562564
"model.config.use_cache = False\n",
563565
"\n",
564-
"# Reference model\n",
565-
"ref_model = AutoModelForCausalLM.from_pretrained(\n",
566-
" model_name,\n",
567-
" torch_dtype=torch.float16,\n",
568-
" load_in_4bit=True\n",
569-
")\n",
570-
"\n",
571566
"# Training arguments\n",
572567
"training_args = TrainingArguments(\n",
573568
" per_device_train_batch_size=4,\n",
@@ -588,7 +583,6 @@
588583
"# Create DPO trainer\n",
589584
"dpo_trainer = DPOTrainer(\n",
590585
" model,\n",
591-
" ref_model,\n",
592586
" args=training_args,\n",
593587
" train_dataset=dataset,\n",
594588
" tokenizer=tokenizer,\n",
@@ -624,7 +618,7 @@
624618
"tokenizer.save_pretrained(\"final_checkpoint\")\n",
625619
"\n",
626620
"# Flush memory\n",
627-
"del dpo_trainer, model, ref_model\n",
621+
"del dpo_trainer, model\n",
628622
"gc.collect()\n",
629623
"torch.cuda.empty_cache()\n",
630624
"\n",

0 commit comments

Comments
 (0)