Tri Dao
Did you follow the installation instructions in the README? It's here and should be installed if you do `pip install mamba-ssm`: https://github.com/state-spaces/mamba/tree/main/csrc/selective_scan
Can you try downloading that wheel URL manually to check if the networking works?
v1.2.1 now includes wheels for PyTorch 2.3, so it should hopefully fix this issue.
Yup, `d_state` here means each dimension of `d_model` will be expanded by a factor of `d_state`. We used `d_state=16` in our experiments. Thanks @jyegerlehner for the detailed explanation.
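For reference, a minimal usage sketch along the lines of the repo's README example (the constructor arguments assume the `mamba_ssm.Mamba` interface; adjust dimensions to your setup):

```python
import torch
from mamba_ssm import Mamba

batch, seqlen, d_model = 2, 64, 16
x = torch.randn(batch, seqlen, d_model, device="cuda")

block = Mamba(
    d_model=d_model,  # model dimension
    d_state=16,       # SSM state expansion factor: each channel gets a 16-dim hidden state
    d_conv=4,         # local convolution width
    expand=2,         # block expansion factor (d_inner = expand * d_model)
).to("cuda")

y = block(x)
assert y.shape == x.shape  # (batch, seqlen, d_model)
```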
> I'm thinking of backpropagation through time on multiple chunks (e.g. 4) instead of fitting full sequence in one huge window

Would that take the same amount of activation memory...
I don't get why there would be a spike in memory. Are you saying it happens on consumer-grade GPUs and not on data center GPUs? I don't think we...
For layer norm specifically, what if you use torch.nn.LayerNorm? If the memory blows up there, then it's a general problem and not specific to the Mamba model.
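A minimal sketch of that isolation test (the shapes here are hypothetical placeholders; substitute the ones from your config and compare peak memory with and without Mamba in the loop):

```python
import torch

# Hypothetical shapes -- substitute the ones from your model/config.
batch, seqlen, dim = 8, 4096, 2048
x = torch.randn(batch, seqlen, dim, device="cuda", requires_grad=True)
ln = torch.nn.LayerNorm(dim).to("cuda")

torch.cuda.reset_peak_memory_stats()
y = ln(x)          # forward
y.sum().backward() # backward
torch.cuda.synchronize()
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```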
I've no idea why there would be a memory spike; if you figure it out, let me know.
We use the work-efficient version (Blelloch's scan).
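For intuition, here's a small pure-Python sketch of the Blelloch up-sweep/down-sweep pattern for an exclusive sum scan. This is illustrative only: the actual CUDA kernel runs the inner loops in parallel and scans the SSM recurrence rather than plain addition.

```python
def blelloch_exclusive_scan(a):
    """Work-efficient (Blelloch) exclusive prefix sum.
    Written as sequential loops; on a GPU each inner loop runs in parallel.
    Assumes len(a) is a power of two."""
    x = list(a)
    n = len(x)

    # Up-sweep (reduce) phase: build partial sums up a balanced tree.
    d = 1
    while d < n:
        for i in range(0, n, 2 * d):
            x[i + 2 * d - 1] += x[i + d - 1]
        d *= 2

    # Down-sweep phase: push prefixes back down the tree.
    x[n - 1] = 0
    d = n // 2
    while d >= 1:
        for i in range(0, n, 2 * d):
            t = x[i + d - 1]
            x[i + d - 1] = x[i + 2 * d - 1]
            x[i + 2 * d - 1] += t
        d //= 2
    return x

print(blelloch_exclusive_scan([3, 1, 7, 0, 4, 1, 6, 3]))
# [0, 3, 4, 11, 11, 15, 16, 22]
```

The work-efficient variant does O(n) additions total (versus O(n log n) for the naive Hillis-Steele scan), at the cost of two tree passes.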
It's a sequence-to-sequence mapping, with input (batch, seqlen, dim) and output (batch, seqlen, dim). You can use it in the same way you'd use any other sequence-to-sequence layer (e.g. attention).
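As an illustration, a hedged sketch of dropping the block in where an attention layer would normally sit (`SequenceModel` and its layer choices here are hypothetical, not the reference architecture):

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba

class SequenceModel(nn.Module):
    """Hypothetical stack: embedding -> N x (norm + Mamba, residual) -> head.
    Any (batch, seqlen, dim) -> (batch, seqlen, dim) layer could sit where Mamba does."""
    def __init__(self, vocab_size, d_model, n_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList(
            nn.ModuleDict({"norm": nn.LayerNorm(d_model), "mixer": Mamba(d_model=d_model)})
            for _ in range(n_layers)
        )
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):  # tokens: (batch, seqlen)
        h = self.embed(tokens)  # (batch, seqlen, d_model)
        for layer in self.layers:
            # Pre-norm residual, same wiring you'd use around an attention layer.
            h = h + layer["mixer"](layer["norm"](h))
        return self.head(h)     # (batch, seqlen, vocab_size)

model = SequenceModel(vocab_size=1000, d_model=256, n_layers=2).cuda()
logits = model(torch.randint(0, 1000, (2, 64), device="cuda"))
print(logits.shape)  # torch.Size([2, 64, 1000])
```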