
Chunked DPO

Open · cyr0930 opened this issue 6 months ago · 0 comments

Summary

Attempts to fix https://github.com/linkedin/Liger-Kernel/issues/439.

As discussed in the issue above, chunking the hidden state along the batch dimension offers only limited benefit. This PR therefore chunks the hidden state along the flattened (batch * seq_len) dimension instead.
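A minimal PyTorch sketch of what chunking over (batch * seq_len) implies, not the code in this PR: once the hidden state is flattened to `(B*T, H)`, a chunk boundary can fall in the middle of a sequence, so per-token log-probs have to be accumulated into per-sequence totals by sequence id. The helper name `chunked_sequence_logps` is hypothetical, and it assumes `labels` are already shifted to line up with `hidden` positions, with `-100` marking ignored tokens.

```python
import torch


def chunked_sequence_logps(hidden, lm_head_weight, labels, chunk_size, ignore_index=-100):
    """Hypothetical helper: per-sequence log-prob sums, chunked over B*T.

    hidden: (B, T, H), lm_head_weight: (V, H), labels: (B, T), assumed pre-shifted.
    """
    B, T, H = hidden.shape
    flat_hidden = hidden.reshape(B * T, H)
    flat_labels = labels.reshape(B * T)
    # Row r of the flattened tensor belongs to sequence r // T.
    seq_ids = torch.arange(B * T, device=hidden.device) // T
    seq_logps = hidden.new_zeros(B)
    for start in range(0, B * T, chunk_size):
        end = min(start + chunk_size, B * T)
        logits = flat_hidden[start:end] @ lm_head_weight.T  # only (chunk, V) at a time
        logp = torch.log_softmax(logits.float(), dim=-1)
        lbl = flat_labels[start:end]
        mask = lbl != ignore_index
        # clamp keeps gather in range for ignored positions; mask zeroes them out.
        tok_logp = logp.gather(-1, lbl.clamp(min=0).unsqueeze(-1)).squeeze(-1) * mask
        # A chunk may cover only part of a sequence, so accumulate by sequence id.
        seq_logps = seq_logps.index_add(0, seq_ids[start:end], tok_logp.to(seq_logps.dtype))
    return seq_logps
```

With `chunk_size` smaller than `seq_len` (or not dividing it), a single sequence's log-prob is spread across several chunks; `index_add` is what stitches it back together.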

Because a chunk may now contain only part of a sequence, the loss requires a non-trivial online computation (each sequence's log-probability must be accumulated across chunks before the DPO loss can be evaluated), so the fused forward-backward technique cannot be used in this case. However, the memory footprint can still be reduced by slicing the hidden state into small chunks: the backward step can be run chunk by chunk instead of all at once, which avoids the large memory spike, although materializing all of the logits remains unavoidable.
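One way to read the approach described above, sketched in PyTorch under stated assumptions rather than taken from this PR: (1) forward every chunk and keep its graph, so all logits stay materialized; (2) compute the DPO loss on the detached per-sequence log-prob sums and take its gradient with respect to those sums; (3) backpropagate chunk by chunk, passing that gradient into each chunk's partial sums (valid by the chain rule, since each sequence's log-prob is the sum of its chunk contributions). The function names, the chosen-first/rejected-second batch layout, and the precomputed `ref_logps` input are all assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def dpo_loss_from_logps(policy_logps, ref_logps, beta=0.1):
    # Assumed layout: rows [0, B/2) are chosen, rows [B/2, B) are rejected.
    half = policy_logps.shape[0] // 2
    logratios = policy_logps - ref_logps
    margin = logratios[:half] - logratios[half:]
    return -F.logsigmoid(beta * margin).mean()


def chunked_dpo_backward(hidden, lm_head_weight, labels, ref_logps, chunk_size,
                         beta=0.1, ignore_index=-100):
    """Hypothetical sketch: returns (loss, grad w.r.t. hidden); also fills lm_head_weight.grad."""
    B, T, H = hidden.shape
    # Separate leaf so every chunk graph ends in the same AccumulateGrad node.
    flat_hidden = hidden.detach().reshape(B * T, H).clone().requires_grad_(True)
    flat_labels = labels.reshape(B * T)
    seq_ids = torch.arange(B * T, device=hidden.device) // T

    # Phase 1: forward each chunk and keep its graph (all logits stay alive).
    partials = []
    seq_logps = torch.zeros(B, device=hidden.device)
    for start in range(0, B * T, chunk_size):
        end = min(start + chunk_size, B * T)
        logits = flat_hidden[start:end] @ lm_head_weight.T
        logp = torch.log_softmax(logits.float(), dim=-1)
        lbl = flat_labels[start:end]
        tok_logp = logp.gather(-1, lbl.clamp(min=0).unsqueeze(-1)).squeeze(-1) * (lbl != ignore_index)
        # This chunk's contribution to each sequence's total log-prob.
        partial = tok_logp.new_zeros(B).index_add(0, seq_ids[start:end], tok_logp)
        partials.append(partial)
        seq_logps = seq_logps + partial.detach()

    # Phase 2: the "online" loss only needs the per-sequence totals.
    seq_leaf = seq_logps.requires_grad_(True)
    loss = dpo_loss_from_logps(seq_leaf, ref_logps, beta)
    (grad_seq,) = torch.autograd.grad(loss, seq_leaf)

    # Phase 3: backward one chunk at a time; each call frees that chunk's graph.
    for partial in partials:
        partial.backward(grad_seq)
    return loss.detach(), flat_hidden.grad.reshape(B, T, H)
```

In a real training step, the returned gradient would be pushed into the rest of the model with `hidden.backward(grad_hidden)`. The design trades extra bookkeeping for never allocating the gradient of the full `(B*T, V)` logits tensor at once.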

The implementation in this PR may not be final yet; for now I mainly want to check whether the idea is valid and aligned with the spirit of Liger-Kernel. Any feedback would be appreciated. Thanks.

Testing Done

I haven't run the tests yet, because I first want to confirm that this concept would be acceptable.

  • Hardware Type: A100-SXM4-80GB
  • [ ] run make test to ensure correctness
  • [ ] run make checkstyle to ensure code style
  • [ ] run make test-convergence to ensure convergence

cyr0930, May 21 '25 07:05