
`ref_model` not needed in `Fine_tune_a_Mistral_7b_model_with_DPO.ipynb`

alvarobartt opened this issue 1 year ago • 4 comments

Hi here @mlabonne! Congratulations on your awesome work with this course 🤝🏻

After going through Fine_tune_a_Mistral_7b_model_with_DPO.ipynb, I realised that there's no need to define the ref_model for DPO: when fine-tuning with LoRA, no separate reference model is required, because the base model with the adapters disabled is used to compute the reference logprobs. So you can remove the ref_model and get the same result while using even fewer resources. A minimal sketch of the resulting setup is below.
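For reference, here's roughly what the trainer setup looks like without a ref_model (untested sketch; the model name, LoRA targets, hyperparameters, and toy dataset are illustrative stand-ins for the notebook's actual values):

```python
# Sketch of a LoRA DPO setup without a ref_model (assumes trl ~0.7.x,
# matching the __init__ signature in the traceback later in this thread).
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig
from trl import DPOTrainer

model_name = "teknium/OpenHermes-2.5-Mistral-7B"  # base model; illustrative
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Toy preference dataset with the three columns DPOTrainer expects;
# the notebook builds this from a real preference dataset instead.
dataset = Dataset.from_dict({
    "prompt": ["What is DPO?"],
    "chosen": ["Direct Preference Optimization aligns a model on preference pairs."],
    "rejected": ["No idea."],
})

peft_config = LoraConfig(
    r=16, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

training_args = TrainingArguments(
    output_dir="dpo-out",
    per_device_train_batch_size=1,
    remove_unused_columns=False,  # DPOTrainer consumes the raw text columns
)

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,            # the point of this issue: with peft_config set,
                               # the base weights (adapters disabled) act as
                               # the reference model
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,
    max_prompt_length=512,
    max_length=1024,
)
dpo_trainer.train()
```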

Finally, as a tip: when using the DPOTrainer for full fine-tunes, you can instead set precompute_ref_log_probs=True to compute the reference logprobs in advance, before the actual fine-tune starts, so that a ref_model is not needed there either (see the sketch below).
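Reusing the names from the sketch above, that variant looks roughly like this (again untested; note there is no peft_config here, since this is the full fine-tune case):

```python
# Full fine-tune without keeping a separate reference model in memory:
# with precompute_ref_log_probs=True, the reference logprobs are computed
# once from the still-untouched model before training starts.
dpo_trainer = DPOTrainer(
    model,
    ref_model=None,                  # no copy of the model is kept around
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    precompute_ref_log_probs=True,   # cache reference logprobs up front
    beta=0.1,
)
dpo_trainer.train()
```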

alvarobartt avatar Jan 31 '24 13:01 alvarobartt

Hey @alvarobartt, thanks a lot for the hints. I am using the above notebook, and your suggestion solved my memory issue on Google Colab.

AzizCode92 avatar Feb 05 '24 20:02 AzizCode92

Yep, if you run DPOTrainer while passing the ref model, you get the error below. The fix is to remove the ref_model argument from the DPOTrainer call (and to clean up the declaration of ref_model).

Thanks @mlabonne for this super notebook, which got me started going beyond SFT with my first DPO tune.

```
File /usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:217, in DPOTrainer.__init__(self, model, ref_model, beta, label_smoothing, loss_type, args, data_collator, label_pad_token_id, padding_value, truncation_mode, train_dataset, eval_dataset, tokenizer, model_init, callbacks, optimizers, preprocess_logits_for_metrics, max_length, max_prompt_length, max_target_length, peft_config, is_encoder_decoder, disable_dropout, generate_during_eval, compute_metrics, precompute_ref_log_probs, dataset_num_proc, model_init_kwargs, ref_model_init_kwargs, model_adapter_name, ref_adapter_name, reference_free)
    214     model = model.merge_and_unload()
    216 if ref_model is not None:
--> 217     raise ValueError(
    218         "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
    219         " model. Please pass `ref_model=None` in case you want to train PEFT adapters."
    220     )
    222 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
    223     _support_gc_kwargs = hasattr(
    224         args, "gradient_checkpointing_kwargs"
    225     ) and "gradient_checkpointing_kwargs" in list(
    226         inspect.signature(prepare_model_for_kbit_training).parameters
    227     )

ValueError: You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference model. Please pass `ref_model=None` in case you want to train PEFT adapters.
```
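Concretely, the change amounts to something like this (sketch; argument names follow the __init__ signature in the traceback, other kwargs as in the notebook):

```python
# Before: raises the ValueError above, because peft_config is also passed
dpo_trainer = DPOTrainer(
    model,
    ref_model,                # remove this argument and its declaration
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,
)

# After: ref_model omitted; the base weights act as the reference
dpo_trainer = DPOTrainer(
    model,
    ref_model=None,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,
)
```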

corticalstack avatar Feb 18 '24 20:02 corticalstack

Thanks @alvarobartt for opening this issue! I faced the same problem, and your suggestion solved it. I removed the declaration of ref_model as @corticalstack suggested, and I also removed the ref_model argument from the DPOTrainer. Has anyone opened a PR for this fix? If not, I'm happy to open one!

RGaonkar avatar Mar 22 '24 23:03 RGaonkar

I updated the notebook and removed the ref_model. Please let me know if it broke something; I couldn't test it.

mlabonne avatar Mar 24 '24 11:03 mlabonne