trl

Add model merging callback

Open · lewtun opened this issue 1 year ago · 6 comments

Feature request

Add a MergeModelCallback that merges the reference model with the current policy and optionally pushes the merged checkpoint to the Hub. This could be done at the end of each step or epoch and/or at the end of training. Implementation-wise, we could use Arcee's mergekit library and include it as an optional dependency: https://github.com/arcee-ai/mergekit
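To make the proposal concrete, here is a minimal sketch of the simplest merge such a callback could perform: a linear interpolation (weight averaging) of the reference model and the current policy, parameter by parameter. It assumes both models share the same architecture, so their state dicts have identical parameter names and shapes. Plain Python lists stand in for tensors so the example runs without PyTorch, and `linear_merge` and `policy_weight` are hypothetical names, not part of the TRL or mergekit API. mergekit itself provides this and more elaborate methods such as SLERP and TIES.

```python
from typing import Dict, List

# Hypothetical stand-in for a model state dict: parameter name -> flat values.
StateDict = Dict[str, List[float]]


def linear_merge(ref: StateDict, policy: StateDict, policy_weight: float = 0.5) -> StateDict:
    """Return (1 - w) * ref + w * policy for every parameter (hypothetical helper)."""
    if not 0.0 <= policy_weight <= 1.0:
        raise ValueError("policy_weight must be in [0, 1]")
    if ref.keys() != policy.keys():
        # Same architecture is assumed, so parameter names must match exactly.
        raise ValueError("models must have identical parameter names")
    merged: StateDict = {}
    for name, ref_values in ref.items():
        pol_values = policy[name]
        if len(ref_values) != len(pol_values):
            raise ValueError(f"shape mismatch for parameter {name!r}")
        # Element-wise interpolation between the two models' weights.
        merged[name] = [
            (1.0 - policy_weight) * r + policy_weight * p
            for r, p in zip(ref_values, pol_values)
        ]
    return merged


ref_model = {"layer.weight": [0.0, 2.0], "layer.bias": [1.0]}
policy_model = {"layer.weight": [1.0, 4.0], "layer.bias": [3.0]}
print(linear_merge(ref_model, policy_model))
# {'layer.weight': [0.5, 3.0], 'layer.bias': [2.0]}
```

Setting `policy_weight=1.0` returns the policy unchanged and `0.0` returns the reference model, so the weight controls how far the merged checkpoint drifts from the reference.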

Motivation

Several papers show that model merging can improve performance non-trivially, especially when the models share the same architecture:

  • https://arxiv.org/abs/2410.10801
  • https://arxiv.org/abs/2406.16768 (for reward models)

Your contribution

Open to the community!

lewtun · Oct 16 '24 12:10

I'm interested in working on this!

coding-famer · Oct 17 '24 23:10

Nice! Thanks @coding-famer. Feel free to open a PR and ask for help if needed.

qgallouedec · Oct 18 '24 13:10

@lewtun After reading the paper, I noticed that the DPO checkpoints were merged with a different model rather than with the reference model used in DPO training. So I added an option in my PR to specify an external model to merge with instead of the reference model.
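A rough sketch of how such an option could translate into a mergekit run: the policy checkpoint is paired with the external model when one is given, and with the DPO reference model otherwise. The helper name `build_merge_config`, its parameters, and the example paths are hypothetical. The dictionary keys (`merge_method`, `models`, `parameters`, `weight`, `dtype`) are assumed to follow mergekit's YAML config format; check the mergekit README before relying on them.

```python
from typing import Any, Dict, Optional


def build_merge_config(
    policy_path: str,
    ref_model_path: str,
    external_model_path: Optional[str] = None,
    policy_weight: float = 0.5,
    dtype: str = "bfloat16",
) -> Dict[str, Any]:
    """Build a mergekit-style linear merge config (hypothetical helper).

    Merges the policy with an external model if one is provided,
    otherwise falls back to the DPO reference model.
    """
    partner = external_model_path if external_model_path is not None else ref_model_path
    return {
        "merge_method": "linear",
        "models": [
            {"model": policy_path, "parameters": {"weight": policy_weight}},
            {"model": partner, "parameters": {"weight": 1.0 - policy_weight}},
        ],
        "dtype": dtype,
    }


# Hypothetical paths, for illustration only.
config = build_merge_config(
    "output/checkpoint-500",
    "my-org/sft-model",
    external_model_path="my-org/other-model",
)
print(config["models"][1]["model"])  # my-org/other-model
```

The resulting dictionary could then be written to YAML and passed to mergekit's `mergekit-yaml` command; leaving `external_model_path` unset reproduces the original "merge with the reference model" behaviour.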

August-murr · Oct 25 '24 10:10

Hi @August-murr, happy to see that you've already worked it out! However, I noticed that your implementation only merges models on disk after training, which users could already do themselves by running mergekit once training finishes. I think the point here is to merge the model during training, at the end of each step or epoch?

coding-famer · Oct 25 '24 18:10

@coding-famer The callback has an optional parameter called merge_at_every_checkpoint, which merges each saved checkpoint during training, either at every save step or at the end of each epoch.
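A minimal sketch of the trigger logic described above, assuming the callback merges once at the end of training by default and additionally on every checkpoint save when `merge_at_every_checkpoint` is enabled. In transformers, the matching hooks are `TrainerCallback.on_save` and `TrainerCallback.on_train_end`, which receive `(args, state, control, **kwargs)`; here they take only a checkpoint directory so the example runs standalone. The class name, `merge_fn`, and `fake_merge` are hypothetical placeholders, not the PR's actual code.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class MergeModelCallbackSketch:
    """Hypothetical stand-in for a transformers.TrainerCallback subclass."""

    merge_fn: Callable[[str], str]  # merges one checkpoint dir, returns the output dir
    merge_at_every_checkpoint: bool = False
    merged_dirs: List[str] = field(default_factory=list)

    def on_save(self, checkpoint_dir: str) -> None:
        # Runs after each checkpoint is written (every save step or epoch,
        # depending on the trainer's save strategy).
        if self.merge_at_every_checkpoint:
            self.merged_dirs.append(self.merge_fn(checkpoint_dir))

    def on_train_end(self, final_dir: str) -> None:
        # Assumed default: always merge the final model once training finishes.
        self.merged_dirs.append(self.merge_fn(final_dir))


def fake_merge(checkpoint_dir: str) -> str:
    # Placeholder for a real merge (e.g. a mergekit run); only names the output.
    return f"{checkpoint_dir}-merged"


cb = MergeModelCallbackSketch(merge_fn=fake_merge, merge_at_every_checkpoint=True)
for step in (500, 1000):
    cb.on_save(f"output/checkpoint-{step}")
cb.on_train_end("output/final")
print(cb.merged_dirs)
# ['output/checkpoint-500-merged', 'output/checkpoint-1000-merged', 'output/final-merged']
```

With the flag left at `False`, only the final model is merged, which corresponds to the "merge after training" behaviour discussed above.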

August-murr · Oct 25 '24 18:10

Sounds great!

coding-famer · Oct 25 '24 19:10