[Feature] Support variable-length sequences for mamba block
Support variable-length sequences for the mamba block by passing cu_seqlens (cumulative sequence lengths) to the forward and backward passes, similar to what the flash-attention varlen_fwd/varlen_bwd API already supports via cu_seqlens or, equivalently, a lower-triangular block-diagonal attention mask.
In our tests, training with variable-length sequences on real-world datasets brings a 2~4x speedup.
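For readers unfamiliar with the flash-attn varlen convention, here is a minimal sketch of how cu_seqlens is typically built from the individual sequence lengths; the variable names are illustrative and not taken from this PR's code.

```python
import torch
import torch.nn.functional as F

# Lengths of the individual sequences that will be packed together.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Prepend a zero and take the cumulative sum to get the sequence boundaries.
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32)
print(cu_seqlens)  # tensor([0, 5, 8, 15], dtype=torch.int32)

# Sequence i occupies tokens cu_seqlens[i] : cu_seqlens[i + 1] of the packed batch.
```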
- Why do we need this? It avoids wasting compute on meaningless padded tokens, which yields high speedup and better hardware utilization whenever sequence lengths vary. This is especially useful for mamba training on real-world datasets, where the length distribution varies widely and a large proportion of samples are short sequences. Last but not least, we ensure exact forward/backward numerical equality with the padding approach.
- How to use? There is zero learning overhead: the packed mamba API mirrors the packed flash-attn and packed mamba2 APIs. Just pack multiple variable-length sequences into one and additionally pass cu_seqlens into the mamba forward pass, as in the sketch below.
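A minimal usage sketch, assuming the block's forward accepts a cu_seqlens keyword as this PR proposes; the dimensions and tensor names are illustrative, not a definitive interface.

```python
import torch
import torch.nn.functional as F
from mamba_ssm import Mamba

block = Mamba(d_model=64).cuda()

# Three variable-length sequences packed along the token dimension (batch size 1).
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32).cuda()

x = torch.randn(1, int(seq_lens.sum()), 64, device="cuda")  # packed hidden states

# Without cu_seqlens the pack is treated as one long sequence; with it, the
# recurrent state is reset at every boundary, matching the padded result exactly.
y = block(x, cu_seqlens=cu_seqlens)  # cu_seqlens keyword is the feature proposed here
```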
Note: We thank @wang-zerui for the forward-pass Python reference implementation and for the discussion on how to ensure numerical equality. This is joint work with @wang-zerui, @Dmovic, and @ptxu78.
Some related issues about mamba and flash-attn variable-length training:
- https://github.com/state-spaces/mamba/issues/236
- https://github.com/state-spaces/mamba/issues/356
- https://github.com/state-spaces/mamba/issues/180
- https://github.com/state-spaces/mamba/issues/246#issuecomment-2003017621
- https://github.com/Dao-AILab/flash-attention/issues/850#issuecomment-1980308347
- https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1698610752