selective_scan_cuda
Hi, great work!!!
How can I look into the implementation of selective_scan_cuda.fwd and bwd?
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)
Looking forward to your reply.
Hi @EddieEduardo, selective_scan_cuda is a CUDA C++ extension. The fwd (forward) and bwd (backward) functions are defined via CUDA kernels in the selective_scan_fwd_kernel.cuh and selective_scan_bwd_kernel.cuh files, respectively.
This extension is built while installing mamba (from source), as per the setup in setup.py:
https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/setup.py#L236-L254
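Once it's built, you can locate the compiled module and list the functions it exposes directly from Python:

```python
import selective_scan_cuda

# Path to the compiled shared library (the sources live under csrc/selective_scan)
print(selective_scan_cuda.__file__)
# The functions the extension exposes, e.g. fwd and bwd
print([name for name in dir(selective_scan_cuda) if not name.startswith("_")])
```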
Thanks for your reply!
Could you tell me what the three returned values (out, scan_intermediates, out_z) represent, respectively?
same question!
out is the selective scan output (i.e., typical Mamba if we are not fusing in the skip connection). Let's say that out has the following form:
$$out = SSM(inputs)$$
Then scan_intermediates holds the hidden-state values saved partway through the scan at regular intervals. These are checkpointed for the backward pass so that it doesn't take too long for really long sequences. For example, for a sequence of length 2048 we can save the hidden state every 256 tokens, giving 8 intermediate states, which are then reused in the backward pass for optimization purposes. Finally, out_z is just the SSM output multiplied by the gate tensor z, if you have one. That is essentially:
$$out_z = out \cdot \mathrm{SiLU}(z)$$
where $\mathrm{SiLU}(z) = z \cdot \sigma(z)$ is the gate activation ($\sigma$ being the sigmoid). This is again a way to speed things up by fusing the gated multiplication into the SSM operation. Hope this helps, and feel free to ping me if you have other questions about it.
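To make the three return values concrete, here is a minimal sequential sketch in plain PyTorch of roughly what the fused kernel computes. This is illustrative only: the chunk size, the (batch, dim, seqlen) layout, and returning the checkpointed states as a list are assumptions, and the real kernel does all of this in a single fused pass.

```python
import torch
import torch.nn.functional as F

def selective_scan_sketch(u, delta, A, B, C, D, z=None, delta_bias=None,
                          delta_softplus=False, chunk=256):
    # Assumed shapes: u, delta, z: (batch, dim, seqlen); A: (dim, dstate);
    # B, C: (batch, dstate, seqlen); D: (dim,)
    if delta_bias is not None:
        delta = delta + delta_bias[None, :, None]
    if delta_softplus:
        delta = F.softplus(delta)                  # softplus acts on delta only
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    h = u.new_zeros(batch, dim, dstate)            # hidden state
    ys, intermediates = [], []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t, None] * A)   # discretized A
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        h = dA * h + dBu                           # state update
        ys.append((h * C[:, None, :, t]).sum(-1))  # y_t = C h_t
        if (t + 1) % chunk == 0:
            intermediates.append(h)                # states checkpointed for backward
    out = torch.stack(ys, dim=-1) + u * D[:, None]      # plus the skip term D * u
    out_z = out * F.silu(z) if z is not None else None  # SiLU gate
    return out, intermediates, out_z
```

With seqlen 2048 and chunk=256 this checkpoints 8 intermediate states, matching the example above.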
@Hprairie I am trying to export/compile the selective_scan at https://github.com/pytorch/pytorch/issues/130150#issuecomment-2543153878
Do you have an exportable/compilable version?
I don't have one right now (might in a few weeks with a CUDA implementation for Mamba2). I know the lingua repo from FAIR has one that I have tested and works. It should be fairly simple to wrap it though.
The lingua repo doesn't have a selective_scan custom_op: https://github.com/facebookresearch/lingua/tree/main/apps/mamba/component
Ahh, you're right, they only have Mamba2. It should be fairly easy to set up. I looked at your error in pytorch#130150, and your problem is that you are trying to cast a list containing a tensor into a tensor. You should just need to unpack it correctly.
Yes, I found that one, but it seems I hit a chain of issues when correctly wrapping your selective_scan in a regular custom_op.
Can we just hot-swap selective_scan with mamba_chunk_scan_combined?
I want to try to export this NeurIPS 2024 model, and its selective_scan is an issue without a custom_op:
https://github.com/MCG-NJU/VFIMamba/blob/main/model/feature_extractor.py#L225
I would just write a new custom_op around this function, creating functions for both the forward and the backward. Remember to register fake-tensor implementations and then a context function; after that it should be easy to do. If you are not in a super rush, I will probably have a good implementation in a few weeks.
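If it helps, a rough sketch of what that wrapper could look like is below (this needs a recent PyTorch with torch.library.custom_op, I believe 2.4+). Everything here is illustrative: the op name mamba_ext::selective_scan, the checkpointed-state shape in the fake impl, and the argument order of selective_scan_cuda.bwd are assumptions to verify against mamba_ssm/ops/selective_scan_interface.py and the kernel sources.

```python
from typing import Tuple

import torch
import selective_scan_cuda  # built when installing mamba from source

@torch.library.custom_op("mamba_ext::selective_scan", mutates_args=())
def selective_scan(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                   B: torch.Tensor, C: torch.Tensor, D: torch.Tensor,
                   z: torch.Tensor, delta_bias: torch.Tensor,
                   delta_softplus: bool
                   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    out, x, out_z = selective_scan_cuda.fwd(u, delta, A, B, C, D, z,
                                            delta_bias, delta_softplus)
    return out, x, out_z

@selective_scan.register_fake
def _(u, delta, A, B, C, D, z, delta_bias, delta_softplus):
    # Fake (meta) tensors so export/compile can trace shapes without running
    # the kernel. The checkpointed-state shape assumes 2048-token chunks with
    # 2 * dstate per state -- verify against selective_scan_fwd_kernel.cuh.
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    n_chunks = (seqlen + 2047) // 2048
    x = u.new_empty(batch, dim, n_chunks, dstate * 2)
    return torch.empty_like(u), x, torch.empty_like(u)

def _setup_context(ctx, inputs, output):
    u, delta, A, B, C, D, z, delta_bias, delta_softplus = inputs
    out, x, _ = output
    ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
    ctx.delta_softplus = delta_softplus

def _backward(ctx, grad_out, grad_x, grad_out_z):
    # Assumes downstream code consumes out_z (the gated output). The bwd
    # argument order below mirrors selective_scan_interface.py as I recall
    # it -- double-check it against your installed version.
    u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
    du, ddelta, dA, dB, dC, dD, ddelta_bias, dz, *_ = selective_scan_cuda.bwd(
        u, delta, A, B, C, D, z, delta_bias, grad_out_z, x, out, None,
        ctx.delta_softplus, False)
    # Gradients must line up with the op's inputs; delta_softplus gets None.
    return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias, None

selective_scan.register_autograd(_backward, setup_context=_setup_context)
```

With this registered, torch.export/torch.compile should treat the scan as a single opaque op instead of trying to trace into the extension.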
I have at least a running draft. Let me know if you want to review it as a PR in this repo.
In the Mamba and Vision Mamba papers, the activation function is stated as SiLU, but in the code, the authors used Softplus. Please clarify if my understanding is incorrect.
SiLU is used so that the output to the mamba layer is gated. Softplus is used for calculating delta.
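In code terms, an illustrative sketch of where the two activations live (tensor names and shapes here are placeholders; the real kernel fuses both steps into the scan):

```python
import torch
import torch.nn.functional as F

def activations_sketch(dt: torch.Tensor, dt_bias: torch.Tensor,
                       out: torch.Tensor, z: torch.Tensor):
    # delta_softplus=True: softplus keeps the step size delta positive
    delta = F.softplus(dt + dt_bias[None, :, None])
    # the output gate uses SiLU: silu(z) = z * sigmoid(z)
    out_z = out * F.silu(z)
    return delta, out_z
```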
Thank you so much for your prompt reply! Your response was very helpful. As you mentioned, SiLU is used as the activation function for the gating mechanism.
In the following function call:
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)
out_z is the output of Mamba after being processed by the gating mechanism. Could you please tell me where SiLU is applied within the function selective_scan_cuda.fwd()?
Look at the CUDA kernel; it's at the very end, right before the write-out :)
Thank you for your reply. I have another question: what is the difference between

y = selective_scan_fn(x, dt, A, B, C, self.D.float(), z=None, delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=None)

and

out, scan_intermediates, out_z = selective_scan_cuda.fwd(conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus)?

It seems they both serve as out = SSM(inputs).
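For context, selective_scan_fn (defined in mamba_ssm/ops/selective_scan_interface.py) is the autograd-aware Python wrapper around the raw kernel calls. Structurally it looks roughly like this sketch (not the exact upstream code):

```python
import torch
import selective_scan_cuda  # the raw CUDA extension

class SelectiveScanFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None,
                delta_softplus=False):
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z,
                                                delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
        return rest[0] if z is not None else out  # rest[0] is out_z when z is given

    @staticmethod
    def backward(ctx, dout):
        # delegates to selective_scan_cuda.bwd with the saved tensors
        raise NotImplementedError("see selective_scan_interface.py")

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None,
                      delta_softplus=False, return_last_state=False):
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias,
                                 delta_softplus)
```

So both compute out = SSM(inputs); the wrapper adds the backward pass and returns only the final (optionally gated) output, while calling selective_scan_cuda.fwd directly gives the raw, non-differentiable outputs including the saved intermediates.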
How can I turn a .ckpt checkpoint into a .pt file so I can use it with LibTorch, given that selective_scan_cuda is a CUDA C++ extension that cannot be recognized by torch.jit.script?