`ref_model` not needed in `Fine_tune_a_Mistral_7b_model_with_DPO.ipynb`
Hi there @mlabonne! Congratulations on your awesome work with this course 🤝🏻
After going through `Fine_tune_a_Mistral_7b_model_with_DPO.ipynb`, I realised that there's no need to define the `ref_model` required by DPO: when fine-tuning with LoRA, no separate reference model is needed, since the base model with the adapters disabled is used to compute the reference logprobs. So you can remove the `ref_model` and the result will still be the same, while using even fewer resources; see the sketch below.
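For illustration, here's a minimal sketch of what the `DPOTrainer` call can look like without a reference model (trl ~0.7, matching the traceback further down; the model name, LoRA settings, training arguments, and the tiny in-memory preference dataset are placeholder assumptions, not the notebook's exact values):

```python
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_name = "mistralai/Mistral-7B-v0.1"  # placeholder base model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Illustrative LoRA settings
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# A DPO dataset needs "prompt", "chosen", "rejected" columns; toy example here
train_dataset = Dataset.from_dict({
    "prompt": ["What is DPO?"],
    "chosen": ["Direct Preference Optimization is an RLHF-free alignment method."],
    "rejected": ["I don't know."],
})

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,  # no reference model: the base weights (adapters disabled) serve as reference
    args=TrainingArguments(
        output_dir="dpo_out",
        per_device_train_batch_size=1,
        remove_unused_columns=False,
    ),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,
    max_prompt_length=1024,
    max_length=1536,
)
dpo_trainer.train()
```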
Finally, as a tip: when using the `DPOTrainer` for full fine-tunes, you can also specify `precompute_ref_log_probs` to compute the reference logprobs in advance, before the actual fine-tune starts, so that the `ref_model` is not needed during training either.
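A hedged sketch of that tip, reusing the setup above (in this trl version, `precompute_ref_log_probs` is an `__init__` argument, as the traceback signature further down shows):

```python
# Full fine-tune variant: no peft_config is passed, so trl derives a frozen
# reference copy from `model`; with precompute_ref_log_probs=True the
# reference logprobs are computed once over the dataset before training,
# instead of the reference model being queried on every step.
dpo_trainer_full = DPOTrainer(
    model,
    ref_model=None,                 # trl creates the reference copy itself
    args=TrainingArguments(
        output_dir="dpo_full_out",
        per_device_train_batch_size=1,
        remove_unused_columns=False,
    ),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    precompute_ref_log_probs=True,  # reference logprobs cached up front
    beta=0.1,
)
```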
Hey @alvarobartt, thanks a lot for the hints! I am using the above notebook, and your suggestion solved my memory issue on Google Colab.
Thanks @mlabonne for this super notebook, which got me started going beyond SFT with my first DPO tune.

Yep, if you try to run `DPOTrainer` while passing the ref model, you get the runtime error below. To fix it, you can just comment out `ref_model` in `DPOTrainer` (and clean up the declaration of `ref_model`); see the sketch after the traceback.
```
File /usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:217, in DPOTrainer.__init__(self, model, ref_model, beta, label_smoothing, loss_type, args, data_collator, label_pad_token_id, padding_value, truncation_mode, train_dataset, eval_dataset, tokenizer, model_init, callbacks, optimizers, preprocess_logits_for_metrics, max_length, max_prompt_length, max_target_length, peft_config, is_encoder_decoder, disable_dropout, generate_during_eval, compute_metrics, precompute_ref_log_probs, dataset_num_proc, model_init_kwargs, ref_model_init_kwargs, model_adapter_name, ref_adapter_name, reference_free)
    214     model = model.merge_and_unload()
    216 if ref_model is not None:
--> 217     raise ValueError(
    218         "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
    219         " model. Please pass `ref_model=None` in case you want to train PEFT adapters."
    220     )
    222 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
    223     _support_gc_kwargs = hasattr(
    224         args, "gradient_checkpointing_kwargs"
    225     ) and "gradient_checkpointing_kwargs" in list(
    226         inspect.signature(prepare_model_for_kbit_training).parameters
    227     )

ValueError: You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference model. Please pass `ref_model=None` in case you want to train PEFT adapters.
```
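For concreteness, the change amounts to this (a sketch; variable names are assumed to match the notebook's):

```python
# Before — triggers the ValueError above, because peft_config is also passed:
# ref_model = AutoModelForCausalLM.from_pretrained(model_name, ...)
# dpo_trainer = DPOTrainer(model, ref_model, args=training_args, ...)

# After — drop the ref_model declaration and pass ref_model=None:
dpo_trainer = DPOTrainer(
    model,
    ref_model=None,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
```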
Thanks @alvarobartt for opening this issue! I faced the same problem, and following your suggestion solved it. I removed the declaration of `ref_model` as @corticalstack suggested, and I also removed the `ref_model` argument in the `DPOTrainer`. Has anyone opened a PR for this fix? If not, I am happy to do so!
I updated the notebook and removed the `ref_model`. Please let me know if it broke anything; I couldn't test it.