
Implement CPO (Contrastive Preference Optimization)

CPO seems like an interesting direct-preference-optimization-style loss function which, similar to SimPO, eliminates the need for a reference model. There's also a reference implementation of the loss function in TRL.
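
For context, my reading of the objective in the CPO paper (Xu et al., 2024) is a reference-free DPO-style preference term plus a behaviour-cloning (NLL) term on the chosen responses:

$$
\mathcal{L}_{\text{CPO}} = \mathcal{L}_{\text{prefer}} + \mathcal{L}_{\text{NLL}}
$$

$$
\mathcal{L}_{\text{prefer}} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\Big[\log \sigma\big(\beta \log \pi_\theta(y_w \mid x) - \beta \log \pi_\theta(y_l \mid x)\big)\Big], \qquad
\mathcal{L}_{\text{NLL}} = -\mathbb{E}_{(x, y_w) \sim \mathcal{D}}\big[\log \pi_\theta(y_w \mid x)\big]
$$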

  1. Add a cpo_loss.py in our loss module, using SimPO as an example. Implement the loss, ideally citing a reference implementation (a rough sketch is included after this list). Document the loss in api_ref_modules.rst.
  2. Add a test for cpo_loss, using test_dpo_loss or test_simpo_loss as an example, to ensure the math works as expected for dummy inputs.
  3. Add the loss to the DPO recipe docs, noting that it is a reference-free loss.
  4. Complete a (small?) training run with the torchtune implementation and with a reference implementation (like TRL), and show that the training behaviour is roughly identical.
  5. $$$.
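
For concreteness, here is a minimal sketch of what step 1 could look like, modelled on the pattern I'd expect from SimPOLoss (the class name, forward signature, and returned reward tuple here are assumptions, not a final API). It only covers the reference-free sigmoid preference term; the paper's NLL term on chosen responses (exposed as a weighting knob in TRL, I believe) is left out and could be added inside the loss or in the recipe.

```python
# Hypothetical cpo_loss.py sketch -- names and signatures are illustrative only.
import torch
import torch.nn as nn
import torch.nn.functional as F


class CPOLoss(nn.Module):
    """Contrastive Preference Optimization (Xu et al., 2024).

    Reference-free: only the policy's log-probs for the chosen and rejected
    responses are needed, so no reference model forward pass is required.
    """

    def __init__(self, beta: float = 0.1, label_smoothing: float = 0.0):
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,    # shape (batch,)
        policy_rejected_logps: torch.Tensor,  # shape (batch,)
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Preference margin between chosen and rejected responses.
        logits = policy_chosen_logps - policy_rejected_logps

        # Sigmoid loss with optional label smoothing (conservative variant).
        losses = (
            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
        )

        # Implied rewards, useful for logging reward margins/accuracies.
        chosen_rewards = self.beta * policy_chosen_logps.detach()
        rejected_rewards = self.beta * policy_rejected_logps.detach()
        return losses, chosen_rewards, rejected_rewards


# Dummy-input sanity check in the spirit of test_simpo_loss: ranking the
# chosen response higher should give a smaller loss than the flipped ordering.
if __name__ == "__main__":
    loss_fn = CPOLoss(beta=0.1)
    chosen = torch.tensor([-5.0])
    rejected = torch.tensor([-10.0])
    losses, _, _ = loss_fn(chosen, rejected)
    flipped, _, _ = loss_fn(rejected, chosen)
    assert losses.item() < flipped.item()
```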

See #645, #1223 for prior art and discussion.

SalmanMohammadi (Aug 08 '24 19:08)