
[GRPO] chunk over vocab without materializing logits

Open · kashif opened this issue 2 weeks ago · 0 comments

Summary

This PR updates the forward pass to compute only the required per-token log probabilities, simplifies the loss function interface, and adds comprehensive tests to verify correctness against the Triton implementation:

  • The chunk_forward method in fused_linear_ppo.py now computes log probabilities only for selected tokens (using selected_token_ids), avoiding allocation of large [B, T, V] tensors and instead returning [B, T] tensors for per-token log probabilities. This greatly reduces memory usage, especially for large vocabularies.
  • The loss computation in _compute_chunk_loss now consumes these per-token log probabilities directly, and the loss function interface is changed accordingly (from log_probs to per_token_logps).
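The memory saving described above can be sketched in plain PyTorch: the selected token's logit is a single dot product per position, and the softmax normalizer is accumulated as a running logsumexp over vocab chunks, so no [B, T, V] tensor is ever allocated. The helper name and chunk size below are hypothetical illustrations, not the actual fused_linear_ppo.py code.

```python
import torch

def chunked_per_token_logps(hidden, weight, selected_token_ids, vocab_chunk=1024):
    """Sketch: per-token log-probs without materializing [B, T, V] logits.

    hidden: [B, T, H], weight: [V, H], selected_token_ids: [B, T] (long).
    Chunks over the vocab dimension, keeping only [B, T] intermediates.
    (Hypothetical helper; the real chunk_forward is structured differently.)
    """
    B, T, H = hidden.shape
    V = weight.shape[0]
    # Logit of the selected token: one dot product per position -> [B, T].
    selected_logits = (hidden * weight[selected_token_ids]).sum(-1)
    # Running logsumexp over vocab chunks gives the softmax normalizer
    # without ever holding the full [B, T, V] logits in memory.
    lse = torch.full((B, T), float("-inf"), dtype=hidden.dtype, device=hidden.device)
    for start in range(0, V, vocab_chunk):
        logits_chunk = hidden @ weight[start:start + vocab_chunk].T  # [B, T, chunk]
        lse = torch.logaddexp(lse, logits_chunk.logsumexp(dim=-1))
    return selected_logits - lse  # [B, T] log p(selected token)
```

The peak extra memory is now [B, T, vocab_chunk] rather than [B, T, V], which is what makes large vocabularies tractable.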

Simplification and correctness improvements:

  • The ppo_loss_fn in grpo_loss.py is simplified: it now expects pre-gathered per-token log probabilities, removing the need for an internal .gather() operation and unnecessary handling of full log probability tensors.
  • Redundant arguments and code paths for handling full vocabulary log probabilities are removed, further streamlining the code.
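A minimal sketch of the simplified interface: the clipped surrogate loss can take pre-gathered [B, T] per-token log probabilities directly, so no .gather() over a full-vocabulary tensor is needed inside the loss. The signature below is hypothetical; the real ppo_loss_fn in grpo_loss.py also handles KL terms and multiple loss types.

```python
import torch

def clipped_ppo_loss(per_token_logps, old_per_token_logps, advantages, eps=0.2):
    """Clipped surrogate loss over pre-gathered per-token log-probs.

    per_token_logps, old_per_token_logps, advantages: [B, T] tensors.
    (Hypothetical sketch, not the actual ppo_loss_fn.)
    """
    # Importance ratio from the log-prob difference; no vocab dimension anywhere.
    ratio = torch.exp(per_token_logps - old_per_token_logps)        # [B, T]
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    per_token_loss = -torch.minimum(unclipped, clipped)             # [B, T]
    # Indicator of where clipping was active (useful for diagnostics/tests).
    is_clipped = (clipped < unclipped).float()
    return per_token_loss, is_clipped
```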

Testing and validation:

  • A comprehensive test, test_chunked_vs_triton_grpo_loss, is added to ensure that the chunked, memory-optimized loss matches the Triton kernel implementation across a range of configurations, including different batch sizes, sequence lengths, hidden sizes, vocab sizes, loss types, and hyperparameters. This test checks per-token losses, KL divergences, clipping indicators, and reduced losses for correctness.

Testing Done

  • Hardware Type: <BLANK>
  • [ ] run `make test` to ensure correctness
  • [ ] run `make checkstyle` to ensure code style
  • [ ] run `make test-convergence` to ensure convergence

kashif · Nov 26 '25 08:11