trl
Inquiry about the impact of gradient checkpointing on KL divergence estimation.
I am currently running experiments with the DPO and KTO trainers on a private dataset. I am considering enabling gradient checkpointing to reduce memory usage during backpropagation, but I am unsure how it affects the KL divergence estimation. Could someone provide insight into whether gradient checkpointing changes the KL estimate in any way? Are there potential issues or trade-offs I should be aware of? Any help or guidance would be greatly appreciated.
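For reference, here is a minimal sketch of the kind of setup I have in mind (the model name and dataset below are placeholders standing in for my private data, and exact argument names may vary between trl versions):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_name = "Qwen/Qwen2-0.5B-Instruct"  # placeholder; substitute your own model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Any preference dataset with "prompt"/"chosen"/"rejected" columns works here;
# this public set is only a stand-in for the private dataset.
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

args = DPOConfig(
    output_dir="dpo-gc-test",
    per_device_train_batch_size=2,
    # Recompute activations during the backward pass to save memory.
    gradient_checkpointing=True,
    # Non-reentrant checkpointing is usually recommended with recent PyTorch.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,              # trl can build the frozen reference model internally
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # named `tokenizer=` in older trl releases
)
trainer.train()
```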