
Gradient checkpointing for `grad_weight` in LFCE

cassanof opened this issue 9 months ago · 4 comments

🚀 The feature, motivation and pitch

The LFCE kernel allocates a `grad_weight` tensor up front:
https://github.com/linkedin/Liger-Kernel/blob/a8fa3bb37850e89500261024ff47da0c626ab75f/src/liger_kernel/ops/fused_linear_cross_entropy.py#L47

This tensor is then updated throughout the chunked loss calculation and finally consumed in the backward pass as a custom gradient:
https://github.com/linkedin/Liger-Kernel/blob/a8fa3bb37850e89500261024ff47da0c626ab75f/src/liger_kernel/ops/fused_linear_cross_entropy.py#L127-L136
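For illustration, here is a rough PyTorch sketch of that pattern. This is my own simplification, not the library's code: the actual kernel is implemented in Triton and fuses the logits, loss, and gradient computation, and the function name and chunking below are mine.

```python
import torch
import torch.nn.functional as F

def lfce_forward_sketch(_input, weight, target, num_chunks=4):
    # _input: [N, hidden_size], weight: [vocab_size, hidden_size], target: [N]
    # grad_weight has the same shape as weight: [vocab_size, hidden_size].
    grad_weight = torch.zeros_like(weight)
    loss = torch.zeros((), dtype=torch.float32, device=_input.device)

    for chunk, tgt in zip(_input.chunk(num_chunks), target.chunk(num_chunks)):
        logits = chunk @ weight.t()                      # [chunk, vocab_size]
        loss += F.cross_entropy(logits.float(), tgt, reduction="sum")
        # d(loss)/d(logits) for cross entropy: softmax(logits) - one_hot(tgt)
        dlogits = torch.softmax(logits.float(), dim=-1)
        dlogits[torch.arange(tgt.numel(), device=tgt.device), tgt] -= 1.0
        # Accumulated chunk by chunk; this full [vocab_size, hidden_size]
        # buffer is what must survive from the forward until the backward.
        grad_weight += dlogits.to(chunk.dtype).t() @ chunk

    n = target.numel()
    return loss / n, grad_weight / n
```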

This tensor has shape `[vocab_size, hidden_size]`, so for large models with large vocabularies it becomes very big. Since it must be kept in memory from the forward pass until the backward pass, it makes large pipeline-parallel microbatches at long sequence lengths impossible. It would be great to have gradient checkpointing here; even full recomputation would work.
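To put numbers on it: with Llama-3-style sizes (`vocab_size = 128256`, `hidden_size = 4096`), a single fp32 `grad_weight` buffer is 128256 × 4096 × 4 bytes ≈ 2.1 GB, and pipeline parallelism keeps one such buffer alive per in-flight microbatch.

Below is a minimal sketch of what full recomputation could look like, assuming the chunked pass is simply re-run inside `backward`. This is my illustration, not an existing Liger-Kernel API; `RecomputedLFCE` and `NUM_CHUNKS` are hypothetical names.

```python
import torch
import torch.nn.functional as F

class RecomputedLFCE(torch.autograd.Function):
    """Hypothetical recomputation variant: forward keeps only the loss;
    backward re-runs the chunked pass, so no [vocab_size, hidden_size]
    buffer lives between forward and backward."""

    NUM_CHUNKS = 4

    @staticmethod
    def forward(ctx, _input, weight, target):
        ctx.save_for_backward(_input, weight, target)
        loss = torch.zeros((), dtype=torch.float32, device=_input.device)
        for chunk, tgt in zip(_input.chunk(RecomputedLFCE.NUM_CHUNKS),
                              target.chunk(RecomputedLFCE.NUM_CHUNKS)):
            loss += F.cross_entropy((chunk @ weight.t()).float(),
                                    tgt, reduction="sum")
        return loss / target.numel()

    @staticmethod
    def backward(ctx, grad_output):
        _input, weight, target = ctx.saved_tensors
        grad_weight = torch.zeros_like(weight)   # allocated only now
        grad_input = torch.empty_like(_input)
        n = target.numel()
        idx_all = torch.arange(n, device=_input.device)
        for idx in idx_all.chunk(RecomputedLFCE.NUM_CHUNKS):
            chunk, tgt = _input[idx], target[idx]
            # Recompute the logits and cross-entropy gradient for this chunk.
            dlogits = torch.softmax((chunk @ weight.t()).float(), dim=-1)
            dlogits[torch.arange(tgt.numel(), device=tgt.device), tgt] -= 1.0
            dlogits = (dlogits * grad_output / n).to(chunk.dtype)
            grad_weight += dlogits.t() @ chunk
            grad_input[idx] = dlogits @ weight
        return grad_input, grad_weight, None
```

The trade-off: backward repeats the per-chunk `[N, hidden] × [hidden, vocab]` matmul, spending one extra matmul pass to avoid holding the `[vocab_size, hidden_size]` buffer across the pipeline bubble.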

Alternatives

No response

Additional context

No response

cassanof · Jan 20 '25