
[Feature] Support variable-length sequences for mamba block

Open · zigzagcai opened this issue on Mar 14, 2024 · 40 comments

Support variable-length sequences for the Mamba block by passing cu_seqlens in the forward and backward passes, similar to what FlashAttention's varlen_fwd/varlen_bwd API already supports (cumulative sequence lengths cu_seqlens, or equivalently a lower-triangular block-diagonal attention mask).
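
For context, a minimal sketch of the cu_seqlens convention used by FlashAttention's varlen API (and adopted here): cu_seqlens is the prefix sum of the individual sequence lengths, so sequence i occupies tokens cu_seqlens[i]:cu_seqlens[i+1] of the packed tensor. The concrete lengths below are illustrative only.

```python
import torch

# Three variable-length sequences packed back-to-back along the token dimension.
lengths = [5, 3, 8]

# Prefix sum of the lengths: sequence i occupies tokens
# cu_seqlens[i] : cu_seqlens[i + 1] of the packed batch.
cu_seqlens = torch.cumsum(torch.tensor([0] + lengths), dim=0).to(torch.int32)
print(cu_seqlens.tolist())  # [0, 5, 8, 16]
```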

In our tests, training with variable-length sequences on real-world datasets brings a 2-4x speedup.

  • Why is this needed? Packing delivers high speedup and hardware utilization on the real-world datasets we tested. It avoids wasting compute on meaningless padding tokens when sequences have variable lengths, which is especially useful when training Mamba on real-world datasets whose length distribution varies widely and where a large proportion of samples are short. Last but not least, the packed path is exactly numerically equal to the padding approach in both the forward and backward passes.

  • How to use it? There is zero learning overhead: the packed Mamba API mirrors the packed flash-attn and packed Mamba2 APIs. Simply pack multiple variable-length sequences into one and additionally pass cu_seqlens into the Mamba forward pass, as sketched after this list.
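
A usage sketch under the assumptions above: it presumes a build of mamba_ssm that includes this patch, and the exact keyword name (cu_seqlens here) follows this issue's description and may differ from the merged code.

```python
import torch
from mamba_ssm import Mamba  # assumes a build that includes this patch

d_model, device, dtype = 256, "cuda", torch.float16

# Pack three variable-length sequences into one "batch of one" with no padding.
lengths = [5, 3, 8]
cu_seqlens = torch.cumsum(torch.tensor([0] + lengths), dim=0).to(torch.int32).to(device)
packed = torch.randn(1, sum(lengths), d_model, device=device, dtype=dtype)

block = Mamba(d_model=d_model).to(device=device, dtype=dtype)

# Packed path (hypothetical keyword per this issue): the cumulative boundaries
# tell the kernel where each sequence starts and ends, so the selective scan
# does not carry state across sequence boundaries inside the packed tensor.
out_packed = block(packed, cu_seqlens=cu_seqlens)  # (1, total_tokens, d_model)
```

Because no FLOPs are spent on padding tokens, this is where the reported 2-4x speedup on mixed-length datasets comes from.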

Note: We thank @wang-zerui for the forward-pass Python reference implementation and for the discussion on how to ensure numerical equality. This is joint work with @wang-zerui, @Dmovic, and @ptxu78.

Some related issues about variable-length training for Mamba and flash-attn:

  1. https://github.com/state-spaces/mamba/issues/236
  2. https://github.com/state-spaces/mamba/issues/356
  3. https://github.com/state-spaces/mamba/issues/180
  4. https://github.com/state-spaces/mamba/issues/246#issuecomment-2003017621
  5. https://github.com/Dao-AILab/flash-attention/issues/850#issuecomment-1980308347
  6. https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1698610752
