Liger-Kernel
Chunked DPO
Summary
Attempts to fix https://github.com/linkedin/Liger-Kernel/issues/439.
As discussed in the issue above, chunking the hidden state along the batch dimension brings only limited benefits. Therefore, this PR chunks the hidden state along the flattened (batch*seq_len) dimension instead.
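A minimal sketch of what changes when chunking along the flattened dimension, in plain Python. The function name is hypothetical, the per-token log-probs are passed in directly for brevity, and padding masks are ignored. Tokens are flattened from (B, T) to (B*T,), cut into fixed-size chunks, and each token's log-prob is added to its own sequence's running sum, because a chunk can straddle a sequence boundary.

```python
from typing import List


def chunked_sequence_logp_sums(token_logps: List[List[float]], chunk_size: int) -> List[float]:
    """Sum per-token log-probs for each sequence, visiting the flattened
    (batch*seq_len) tokens one fixed-size chunk at a time (hypothetical helper)."""
    # Flatten (B, T) -> (B*T,) and remember which sequence each token came from.
    flat = [lp for row in token_logps for lp in row]
    seq_ids = [b for b, row in enumerate(token_logps) for _ in row]
    sums = [0.0] * len(token_logps)
    for start in range(0, len(flat), chunk_size):
        chunk = zip(flat[start:start + chunk_size], seq_ids[start:start + chunk_size])
        # A chunk can straddle a sequence boundary, so accumulate online into
        # per-sequence buckets rather than assuming one sequence per chunk.
        for lp, b in chunk:
            sums[b] += lp
    return sums


token_logps = [[-1.0, -2.0, -0.5], [-0.25, -0.75, -1.0]]
print(chunked_sequence_logp_sums(token_logps, chunk_size=2))  # [-3.5, -2.0]
```

With B=2 and T=3, batch chunking yields at most 2 chunks, while flattened chunking with `chunk_size=2` yields 3, and the middle chunk mixes tokens from both sequences. The sums do not depend on the chunk size.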
Because this requires a non-trivial online loss computation, we cannot use the fused forward-backward technique in this case. However, the memory footprint issue can still be addressed by slicing the hidden state into small chunks: the backward step can then be done chunk by chunk instead of all at once, which avoids a high memory spike, although materializing all logits is still unavoidable.
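To make the chunk-by-chunk backward concrete, here is a dependency-free Python sketch of one way it could work; it is not the code in this PR. Assumptions (all hypothetical): a toy lm_head weight of shape (V, H) with no bias, reference-model sequence log-prob sums that are precomputed, no padding mask, and a batch laid out as B/2 chosen sequences followed by their B/2 rejected counterparts. The forward pass keeps every chunk's log-softmax rows (i.e. all logits are materialized), the standard DPO loss is computed from the per-sequence sums, and the backward pass then forms the gradient with respect to the logits for one chunk at a time. Only the gradient for the weight is computed.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def log_softmax(z: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def chunked_dpo_loss_and_grad(
    hidden: Matrix,              # (B*T, H): hidden states flattened over batch and sequence
    weight: Matrix,              # (V, H): toy lm_head weight (hypothetical, no bias)
    targets: Sequence[int],      # (B*T,): next-token ids
    seq_ids: Sequence[int],      # (B*T,): which sequence each token belongs to
    ref_logps: Sequence[float],  # (B,): precomputed reference-model log-prob sums
    beta: float,
    chunk_size: int,
) -> Tuple[float, Matrix]:
    """Assumes sequences [0, B/2) are chosen and [B/2, B) the matching rejected ones."""
    n_tok, n_seq = len(hidden), len(ref_logps)
    n_pairs = n_seq // 2
    vocab, dim = len(weight), len(weight[0])

    # Pass 1 (forward): project one chunk of tokens at a time and accumulate
    # per-sequence log-prob sums online; a chunk may straddle two sequences.
    # The log-softmax rows of every chunk are kept, so all logits are materialized.
    chunks = []
    seq_logps = [0.0] * n_seq
    for start in range(0, n_tok, chunk_size):
        idx = range(start, min(start + chunk_size, n_tok))
        logps = [log_softmax([sum(w * h for w, h in zip(row, hidden[t])) for row in weight]) for t in idx]
        for t, lp in zip(idx, logps):
            seq_logps[seq_ids[t]] += lp[targets[t]]
        chunks.append((idx, logps))

    # The DPO loss only depends on the B sequence sums, so its gradient with
    # respect to each sum is cheap once every chunk has been seen.
    loss = 0.0
    g_seq = [0.0] * n_seq
    for i in range(n_pairs):
        c, r = i, i + n_pairs
        x = beta * ((seq_logps[c] - ref_logps[c]) - (seq_logps[r] - ref_logps[r]))
        # -log(sigmoid(x)) == softplus(-x), written in a numerically stable form.
        loss += (max(-x, 0.0) + math.log1p(math.exp(-abs(x)))) / n_pairs
        coef = beta * sigmoid(-x) / n_pairs
        g_seq[c], g_seq[r] = -coef, coef

    # Pass 2 (backward): the gradient w.r.t. the logits is built for one chunk at a time,
    # dL/dz_t = g_seq[s(t)] * (onehot(y_t) - softmax(z_t)), then folded into dL/dW.
    grad_w = [[0.0] * dim for _ in range(vocab)]
    for idx, logps in chunks:
        dz_chunk = [
            [g_seq[seq_ids[t]] * ((1.0 if v == targets[t] else 0.0) - math.exp(lp[v])) for v in range(vocab)]
            for t, lp in zip(idx, logps)
        ]
        for t, dz in zip(idx, dz_chunk):
            for v in range(vocab):
                for j in range(dim):
                    grad_w[v][j] += dz[v] * hidden[t][j]
        # dz_chunk is rebound on the next iteration, so only one chunk's logit gradient is alive.
    return loss, grad_w
```

Setting `chunk_size` to the total number of tokens gives the unchunked result, which is a convenient consistency check; a finite-difference check on `weight` verifies the hand-written gradient. In a PyTorch version, autograd would replace the manual gradient formulas.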
I'm not sure the implementation in this PR is polished yet, but I want to check whether this idea is valid and aligned with the spirit of Liger-Kernel. Any feedback would be great. Thanks.
Testing Done
I haven't run the tests yet, because I first want to check whether this concept is acceptable.
- Hardware Type: <A100-SXM4-80GB>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence