Liger-Kernel
Gradient checkpointing for `grad_weight` in LFCE
🚀 The feature, motivation and pitch
The LFCE kernel allocates a `grad_weight` tensor:
https://github.com/linkedin/Liger-Kernel/blob/a8fa3bb37850e89500261024ff47da0c626ab75f/src/liger_kernel/ops/fused_linear_cross_entropy.py#L47
This tensor is then updated throughout the chunked loss calculation and finally returned in the backward pass of the custom grad operation:
https://github.com/linkedin/Liger-Kernel/blob/a8fa3bb37850e89500261024ff47da0c626ab75f/src/liger_kernel/ops/fused_linear_cross_entropy.py#L127-L136
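In rough terms, the pattern is the one below (a plain-PyTorch sketch for illustration only; function and attribute names are made up here, the linked source is authoritative and uses Triton kernels instead):

```python
# Rough illustration of the current pattern (not the actual Liger-Kernel code):
# grad_weight is allocated up front, filled chunk by chunk while the loss is
# computed, stashed on ctx, and the backward simply hands it back.
import torch

def lfce_forward_sketch(ctx, _input, weight, target, chunks=8):
    n = _input.shape[0]
    grad_input = torch.zeros_like(_input)
    grad_weight = torch.zeros_like(weight)  # [vocab_size, hidden_size] buffer
    total_loss = _input.new_zeros(())
    for idx in torch.arange(n).chunk(chunks):
        logits = _input[idx] @ weight.t()
        total_loss += torch.nn.functional.cross_entropy(
            logits, target[idx], reduction="sum"
        )
        # Analytical gradient of mean cross-entropy w.r.t. the chunk's logits.
        probs = torch.softmax(logits.float(), dim=-1)
        probs[torch.arange(len(idx)), target[idx]] -= 1.0
        grad_logits = (probs / n).to(weight.dtype)
        grad_input[idx] = grad_logits @ weight
        grad_weight += grad_logits.t() @ _input[idx]  # accumulated across chunks
    ctx.grads = (grad_input, grad_weight)  # kept alive until the backward pass
    return total_loss / n
```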
This `grad_weight` tensor has shape `[vocab_size, hidden_size]`, which for big models with big vocabularies is very large (e.g., on the order of 2 GB in bf16 for a 128K vocabulary and hidden size 8192). Because it has to stay in memory until the backward pass, it makes it impossible to use large pipeline-parallel microbatches at long sequence lengths. It would be great to have gradient checkpointing here; even full recomputation would work.
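As a sketch of what full recomputation could look like (again plain PyTorch with hypothetical names, not a proposed Liger-Kernel API): the backward would keep only `_input`, `weight`, and `target` from the forward pass and rebuild the chunked logit gradients and `grad_weight` on the spot.

```python
# Hypothetical sketch of full recomputation: nothing of shape
# [vocab_size, hidden_size] survives the forward pass; grad_weight is rebuilt
# chunk by chunk in backward.
import torch

class RecomputedLinearCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, weight, target):
        # Only activations/weights are saved; no gradient buffers.
        ctx.save_for_backward(_input, weight, target)
        # The real kernel would also chunk this forward logits computation.
        logits = _input @ weight.t()
        return torch.nn.functional.cross_entropy(logits, target)

    @staticmethod
    def backward(ctx, grad_output):
        _input, weight, target = ctx.saved_tensors
        n = _input.shape[0]
        grad_input = torch.zeros_like(_input)
        grad_weight = torch.zeros_like(weight)
        for idx in torch.arange(n).chunk(8):
            # Recompute this chunk's logits gradient from scratch.
            probs = torch.softmax((_input[idx] @ weight.t()).float(), dim=-1)
            probs[torch.arange(len(idx)), target[idx]] -= 1.0
            grad_logits = (probs * (grad_output / n)).to(weight.dtype)
            grad_input[idx] = grad_logits @ weight
            grad_weight += grad_logits.t() @ _input[idx]
        return grad_input, grad_weight, None
```

Usage would look like `loss = RecomputedLinearCrossEntropy.apply(hidden_states.view(-1, hidden_size), lm_head_weight, labels.view(-1))` (names hypothetical). The trade-off is recomputing the chunked logits/softmax once more in the backward, in exchange for not holding the `[vocab_size, hidden_size]` buffer (and `grad_input`) across the pipeline bubble.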
Alternatives
No response
Additional context
No response