Why is the GPU memory cost of a Mamba2 block larger than that of a full self-attention block? And how can I reduce this memory cost during training?
I found that Mamba2 is much faster than a full self-attention block, but I ran into a memory problem.
I use 12 Mamba2 layers in a vision task, with d_model=128, d_state=16, head_dim=32, expand=2.
I found that the inner dimension of the expanded middle layer (d_inner = expand * d_model = 256) is larger than d_model, whereas the inner dimension of my self-attention block equals d_model.
How can I reduce the inner dimension so that the memory cost is similar to or smaller than that of full self-attention?
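To see where the extra memory comes from, the inner widths can be worked out directly from the config. This is a minimal sketch, assuming the default Mamba2 layout in mamba_ssm (ngroups=1, a single in_proj emitting z, x, B, C and dt, and a depthwise conv over x, B and C) and an attention block with a fused QKV projection; the helper names are hypothetical. With expand=2, the main per-token activations inside the Mamba2 block are roughly twice as wide as those in the attention block, and the in_proj output alone is wider than the fused QKV output.

```python
# Rough per-token widths of the main projections in one Mamba2 block versus
# one self-attention block. ASSUMPTIONS: default Mamba2 layout (ngroups=1,
# no separate d_ssm) and an attention block with a fused QKV projection.

def mamba2_dims(d_model, d_state, headdim, expand, ngroups=1):
    """Return the main inner dimensions of a Mamba2 block (assumed layout)."""
    d_inner = int(expand * d_model)
    if d_inner % headdim != 0:
        raise ValueError("expand * d_model must be divisible by headdim")
    nheads = d_inner // headdim
    # in_proj produces z, x, B, C and dt in a single matmul
    d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads
    # the depthwise conv runs over the concatenated x, B and C channels
    conv_dim = d_inner + 2 * ngroups * d_state
    return {"d_inner": d_inner, "nheads": nheads,
            "d_in_proj": d_in_proj, "conv_dim": conv_dim}


def attention_dims(d_model):
    """Inner widths of a plain self-attention block with inner_dim == d_model."""
    return {"d_inner": d_model, "d_qkv": 3 * d_model}


m = mamba2_dims(d_model=128, d_state=16, headdim=32, expand=2)
a = attention_dims(d_model=128)
print(m)  # {'d_inner': 256, 'nheads': 8, 'd_in_proj': 552, 'conv_dim': 288}
print(a)  # {'d_inner': 128, 'd_qkv': 384}
```

Autograd roughly keeps several tensors of these widths per layer for the backward pass, so over 12 layers the wider Mamba2 activations add up even though no L×L matrix is stored.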
Have you tried reducing the expansion ratio (expand)?
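Here is a minimal sketch of what reducing the ratio does for the config in the question. It assumes the default Mamba2 layout (ngroups=1, d_inner = int(expand * d_model), nheads = d_inner / headdim), and the function name is hypothetical. With d_model=128 and head_dim=32, expand=1 brings d_inner down to d_model, which matches the attention block, and it keeps the head count an integer.

```python
# Hypothetical helper: how the Mamba2 inner widths shrink as `expand` is reduced.
# ASSUMPTION: default Mamba2 layout with ngroups=1 and d_inner = int(expand * d_model).

def mamba2_widths(d_model, d_state, headdim, expand, ngroups=1):
    # inner width of the expanded branch
    d_inner = int(expand * d_model)
    # heads must tile d_inner exactly
    if d_inner % headdim != 0:
        raise ValueError(f"d_inner={d_inner} is not divisible by headdim={headdim}")
    nheads = d_inner // headdim
    # width of the single in_proj output (z, x, B, C, dt)
    d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads
    return d_inner, nheads, d_in_proj


for expand in (2, 1.5, 1):
    d_inner, nheads, d_in_proj = mamba2_widths(128, 16, 32, expand)
    print(f"expand={expand}: d_inner={d_inner}, nheads={nheads}, in_proj={d_in_proj}")
# expand=2: d_inner=256, nheads=8, in_proj=552
# expand=1.5: d_inner=192, nheads=6, in_proj=422
# expand=1: d_inner=128, nheads=4, in_proj=292
```

In mamba_ssm this corresponds to passing a smaller expand when constructing the block, e.g. Mamba2(d_model=128, d_state=16, headdim=32, expand=1), subject to any additional shape constraints the kernels impose. Whether the narrower block hurts accuracy on a given vision task has to be checked empirically.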
@thucz Hi, I have the same problem. I expected Mamba to use much less memory than a standard Transformer block on long sequences, but that does not seem to be the case. Check my gist snippet here. I tested the code with freshly built PyTorch 2.6.0, mamba 2.2.2, and Triton 3.0.0 on CUDA 11.8 with an NVIDIA A100. Mamba's peak memory usage is 932.83 MB, which is larger than the Transformer's 722.90 MB.