
Why is the GPU memory cost of a Mamba2 block higher than that of a full self-attention block, and how can this memory cost be reduced during training?

Open · thucz opened this issue 1 year ago · 1 comment

I found that Mamba2 is much faster than a full self-attention block, but I ran into a memory problem.

I used 12 Mamba2 layers in a vision task, with d_model=128, d_state=16, head_dim=32, expand=2.

I found that the inner dimension of the Mamba2 block is larger than d_model (it is expanded to expand * d_model), while the inner dimension of my self-attention block is equal to d_model.

[screenshot omitted]

How can I reduce the inner dimension so that the memory cost is similar to or smaller than that of full self-attention?

thucz · Jul 26, 2024
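The inner width comes from the expansion factor: Mamba2 projects d_model up to d_inner = expand * d_model (plus extra channels for B, C, and dt), so with expand=2 the activations inside the block are more than twice as wide as in an attention block whose inner dimension equals d_model. A minimal sketch of inspecting this, assuming the mamba_ssm Mamba2 module (the in_proj and d_inner attribute names come from that implementation and may change across versions; everything else is illustrative):

```python
from mamba_ssm import Mamba2

# Hyperparameters from the question above.
d_model, d_state, headdim, expand = 128, 16, 32, 2

block = Mamba2(d_model=d_model, d_state=d_state, headdim=headdim, expand=expand)

# d_inner = expand * d_model = 256: the width of the x/z streams inside the block.
print("d_inner =", block.d_inner)

# The input projection emits x, z, B, C and dt in a single matmul, so its
# output width is even larger than d_inner.
print("in_proj weight shape:", tuple(block.in_proj.weight.shape))
```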

Have you tried reducing the expansion_ratio?

ScottHoang · Jul 30, 2024
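In the mamba_ssm API the relevant constructor argument is expand, which is presumably what expansion_ratio refers to above. A hedged sketch of the effect, assuming that API: with expand=1 the inner width drops back to d_model.

```python
from mamba_ssm import Mamba2

# expand=2 (the default): d_inner = 2 * 128 = 256
wide = Mamba2(d_model=128, d_state=16, headdim=32, expand=2)

# expand=1: d_inner = 128, i.e. the same inner width as a self-attention
# block whose inner dimension equals d_model
narrow = Mamba2(d_model=128, d_state=16, headdim=32, expand=1)

for name, m in [("expand=2", wide), ("expand=1", narrow)]:
    n_params = sum(p.numel() for p in m.parameters())
    print(f"{name}: d_inner={m.d_inner}, params={n_params}")
```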

@thucz Hi, I have the same problem. I expected that, for long sequences, Mamba would use much less memory than a standard transformer block, but it does not seem to. Check my gist snippet here. I tested the code on newly built PyTorch 2.6.0, mamba 2.2.2, and Triton 3.0.0 with CUDA 11.8 on an NVIDIA A100. Mamba has a peak memory usage of 932.83 MB, which is larger than the transformer's 722.90 MB.

santisy · Oct 17, 2024
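The gist itself is not reproduced here; the comparison it describes was presumably along these lines. A rough sketch, assuming mamba_ssm is installed and a CUDA GPU is available; the shapes and the nn.MultiheadAttention baseline are illustrative stand-ins, not the exact benchmark from the gist:

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba2

def peak_mb(fn, x):
    """Run one forward + backward pass and report peak allocated CUDA memory in MB."""
    torch.cuda.reset_peak_memory_stats()
    out = fn(x)
    if isinstance(out, tuple):  # nn.MultiheadAttention returns (output, weights)
        out = out[0]
    out.sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

device = "cuda"
batch, seqlen, d_model = 4, 4096, 128
x = torch.randn(batch, seqlen, d_model, device=device, requires_grad=True)

mamba = Mamba2(d_model=d_model, d_state=16, headdim=32, expand=2).to(device)
attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True).to(device)

print(f"Mamba2 peak:    {peak_mb(mamba, x):.2f} MB")
x.grad = None
print(f"Attention peak: {peak_mb(lambda t: attn(t, t, t), x):.2f} MB")
```

The numbers such a script reports will of course vary with sequence length, batch size, expand, and whether activation checkpointing is used.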