[GRPO] chunk over vocab without materializing logits
Summary
This PR updates the forward pass to compute only the required per-token log probabilities, simplifies the loss function interface, and adds comprehensive tests to verify correctness against the Triton implementation:
- The `chunk_forward` method in `fused_linear_ppo.py` now computes log probabilities only for the selected tokens (using `selected_token_ids`), avoiding allocation of large `[B, T, V]` tensors and instead returning `[B, T]` tensors of per-token log probabilities. This greatly reduces memory usage, especially for large vocabularies; see the sketch after this list.
- The loss computation in `_compute_chunk_loss` is updated to consume these per-token log probabilities directly, and the loss function interface is changed accordingly (from `log_probs` to `per_token_logps`). [1] [2]
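For illustration, here is a minimal PyTorch sketch of the idea; the function name, signature, and chunking strategy are assumptions for exposition, not the actual `chunk_forward` code. Chunking over the vocabulary dimension lets us accumulate the `logsumexp` normalizer and pick up the selected-token logits without ever holding a full `[B, T, V]` logits tensor:

```python
import torch

def per_token_logps_chunked(hidden, lm_weight, selected_token_ids, vocab_chunk=8192):
    """Hypothetical sketch: log p(selected token) per position, chunked over vocab.

    hidden:             [B, T, H] final hidden states
    lm_weight:          [V, H] LM head weight
    selected_token_ids: [B, T] token ids whose log-probs we need
    """
    B, T, _ = hidden.shape
    V = lm_weight.shape[0]
    # Running logsumexp over vocab chunks (the softmax normalizer).
    lse = torch.full((B, T), float("-inf"), dtype=torch.float32, device=hidden.device)
    selected_logits = torch.zeros(B, T, dtype=torch.float32, device=hidden.device)
    for start in range(0, V, vocab_chunk):
        end = min(start + vocab_chunk, V)
        # Only a [B, T, chunk] slice of the logits ever exists at once.
        logits = (hidden @ lm_weight[start:end].T).float()
        lse = torch.logaddexp(lse, torch.logsumexp(logits, dim=-1))
        # Pick up the logit of each selected token that falls in this chunk.
        in_chunk = (selected_token_ids >= start) & (selected_token_ids < end)
        local_ids = (selected_token_ids - start).clamp(0, end - start - 1)
        gathered = logits.gather(-1, local_ids.unsqueeze(-1)).squeeze(-1)
        selected_logits = torch.where(in_chunk, gathered, selected_logits)
    return selected_logits - lse  # [B, T] per-token log-probabilities
```

Peak activation memory for this step drops from `O(B * T * V)` to `O(B * T * vocab_chunk)`, which is the point of the change.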
Simplification and correctness improvements:
- The `ppo_loss_fn` in `grpo_loss.py` is simplified: it now expects pre-gathered per-token log probabilities, removing the need for an internal `.gather()` operation and the unnecessary handling of full log-probability tensors. A rough sketch of the new interface follows this list.
- Redundant arguments and code paths for handling full-vocabulary log probabilities are removed, further streamlining the code.
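As a rough sketch of what the simplified interface implies (argument names and the exact clipping/KL handling here are assumptions; the real `ppo_loss_fn` lives in `grpo_loss.py` and supports more loss types), the loss can now consume `[B, T]` per-token log probabilities directly instead of gathering from a full log-prob tensor:

```python
import torch

def ppo_loss_fn_sketch(per_token_logps, old_per_token_logps, advantages, epsilon=0.2):
    """Hypothetical sketch of a clipped PPO-style loss on pre-gathered log-probs.

    All inputs are [B, T] (advantages may broadcast); no .gather() over a
    [B, T, V] tensor is needed anywhere.
    """
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    clipped_ratio = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)
    # Standard clipped surrogate objective, negated for minimization.
    per_token_loss = -torch.min(ratio * advantages, clipped_ratio * advantages)
    # Indicator of positions where the clipped branch was active.
    is_clipped = ratio * advantages > clipped_ratio * advantages
    return per_token_loss, is_clipped
```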
Testing and validation:
- A comprehensive test, `test_chunked_vs_triton_grpo_loss`, is added to ensure that the chunked, memory-optimized loss matches the Triton kernel implementation across a range of configurations, including different batch sizes, sequence lengths, hidden sizes, vocab sizes, loss types, and hyperparameters. The test checks per-token losses, KL divergences, clipping indicators, and reduced losses for correctness; a much-reduced version of the comparison is sketched below.
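The shape of such a parity test, heavily simplified (names, configurations, and tolerances are illustrative, and it reuses the `per_token_logps_chunked` sketch above as a stand-in for the chunked path):

```python
import pytest
import torch

@pytest.mark.parametrize("B,T,H,V", [(2, 16, 64, 1024), (4, 32, 128, 32000)])
def test_chunked_matches_reference(B, T, H, V):
    torch.manual_seed(0)
    hidden = torch.randn(B, T, H)
    lm_weight = torch.randn(V, H)
    ids = torch.randint(0, V, (B, T))
    # Reference path: materialize the full [B, T, V] log-probs, then gather.
    full_logps = torch.log_softmax((hidden @ lm_weight.T).float(), dim=-1)
    ref = full_logps.gather(-1, ids.unsqueeze(-1)).squeeze(-1)
    # Chunked path: never materializes the full logits.
    out = per_token_logps_chunked(hidden, lm_weight, ids, vocab_chunk=256)
    torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
```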
Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence