Implement CPO (Contrastive Preference Optimization)
CPO seems like an interesting direct-preference-optimisation-style loss function which, similar to SimPO, also eliminates the need for a reference model. There's also a reference implementation for the loss function in TRL.
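For context, my reading of the CPO paper is that the objective combines a reference-free preference term (a DPO-style sigmoid loss computed only from the policy's log-probs) with a behaviour-cloning NLL term on the chosen responses; TRL additionally weights the NLL term with a coefficient (`cpo_alpha`). Roughly:

$$
\mathcal{L}_{\text{CPO}} \;=\; \underbrace{-\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\Big[\log \sigma\big(\beta \log \pi_\theta(y_w \mid x) \;-\; \beta \log \pi_\theta(y_l \mid x)\big)\Big]}_{\mathcal{L}_{\text{prefer}}} \;\;\underbrace{-\;\mathbb{E}_{(x,\,y_w)\sim\mathcal{D}}\big[\log \pi_\theta(y_w \mid x)\big]}_{\mathcal{L}_{\text{NLL}}}
$$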
- Add a `cpo_loss.py` in our `loss` module, using `SimPO` as an example. Implement the loss, ideally citing a reference implementation, and document it in `api_ref_modules.rst` (a rough sketch is included after this list).
- Add a test for `cpo_loss`, using `test_dpo_loss` or `test_simpo_loss` as examples, to ensure the math works as expected for dummy inputs (see the test sketch after this list).
- Add the loss to the DPO recipe docs. It should be a reference-free loss.
- Complete a (small?) training run with the torchtune implementation, and also a reference implementation (like TRL), and show the training behaviour is roughly identical.
- $$$.
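To make the first item concrete, here is a minimal sketch of what the loss module could look like, modeled on the TRL `CPOTrainer` sigmoid loss. The class name, constructor arguments, and the `(chosen_logps, rejected_logps)` forward signature are assumptions (mirroring how I'd expect the SimPO loss to look), not the final API. The behaviour-cloning NLL term on the chosen responses (weighted by `cpo_alpha` in TRL) is left to the recipe, since it needs the chosen-token logits and labels:

```python
# Hypothetical cpo_loss.py sketch; names and the forward signature are
# assumptions modeled on the SimPO loss, not the actual torchtune API.
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class CPOLoss(nn.Module):
    """Contrastive Preference Optimization loss (reference-free).

    Operates on the policy's sequence log-probs for the chosen and
    rejected completions; no reference model is needed.
    """

    def __init__(self, beta: float = 0.1, label_smoothing: float = 0.0):
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,    # shape (batch,)
        policy_rejected_logps: torch.Tensor,  # shape (batch,)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Margin between chosen and rejected sequence log-probs
        logits = policy_chosen_logps - policy_rejected_logps

        # Sigmoid preference loss with optional label smoothing,
        # following the TRL CPOTrainer "sigmoid" variant
        losses = (
            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
        )

        # Implicit "rewards" for logging, mirroring the other
        # reference-free preference losses
        chosen_rewards = self.beta * policy_chosen_logps.detach()
        rejected_rewards = self.beta * policy_rejected_logps.detach()
        return losses, chosen_rewards, rejected_rewards
```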
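And a corresponding test sketch following the shape of `test_simpo_loss`, checking the math against a hand-computed expectation on dummy inputs (the import path and fixture names are assumptions):

```python
# Hypothetical test_cpo_loss.py sketch.
import pytest
import torch
import torch.nn.functional as F

# Assumed import path; wherever CPOLoss ends up living in the loss module.
from torchtune.modules.loss import CPOLoss


class TestCPOLoss:
    @pytest.fixture
    def loss_fn(self):
        return CPOLoss(beta=0.1, label_smoothing=0.0)

    def test_cpo_loss(self, loss_fn):
        policy_chosen_logps = torch.tensor([-0.5, -10.0, -1.0])
        policy_rejected_logps = torch.tensor([-0.1, -30.0, -21.0])

        # With label_smoothing=0, the loss reduces to
        # -logsigmoid(beta * (chosen - rejected))
        expected = -F.logsigmoid(
            0.1 * (policy_chosen_logps - policy_rejected_logps)
        )

        losses, _, _ = loss_fn(policy_chosen_logps, policy_rejected_logps)
        torch.testing.assert_close(losses, expected, rtol=1e-5, atol=1e-5)
```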
See #645 and #1223 for prior art and discussion.