
selective_scan_cuda

Open EddieEduardo opened this issue 1 year ago • 17 comments

Hi, great work!!!

How can I look into the implementation of selective_scan_cuda.fwd and bwd?

out, scan_intermediates, out_z = selective_scan_cuda.fwd(
    conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)

Looking forward to your reply.

EddieEduardo avatar Oct 24 '24 03:10 EddieEduardo

Hi @EddieEduardo, selective_scan_cuda is a CUDA C++ extension. The fwd (forward) and bwd (backward) functions are defined via CUDA kernels in the selective_scan_fwd_kernel.cuh and selective_scan_bwd_kernel.cuh files, respectively.

This extension is built while installing mamba (from source), as per the setup in setup.py:

https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/setup.py#L236-L254
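
Once mamba has been installed from source, a quick sanity check is to import the compiled module directly and list what it exposes. This is just an illustrative snippet; nothing is assumed beyond the fwd/bwd functions mentioned above:

```python
# Quick sanity check that the compiled extension was built and exposes fwd/bwd:
import selective_scan_cuda

print([name for name in dir(selective_scan_cuda) if not name.startswith("_")])
# Expect to see at least 'fwd' and 'bwd' among the exposed bindings.
```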

JeS24 avatar Oct 24 '24 04:10 JeS24

Thanks for your reply!

I would like to know what the three returned values (out, scan_intermediates, out_z) represent, respectively.

EddieEduardo avatar Oct 24 '24 06:10 EddieEduardo

same question!

blank-Cd avatar Oct 31 '24 06:10 blank-Cd

The out is the selective scan output (i.e., typical Mamba if we are not fusing in the skip connection). Let's say that out has the following form:

$$out = SSM(inputs)$$

Then scan_intermediates are the hidden-state values saved partway through the scan at regular intervals. These are checkpointed for the backward pass so that it doesn't take too long for really long sequences. For example, for a sequence of length 2048 we can save the hidden state every 256 tokens to get 8 intermediate states, which are then reused in the backward pass for optimization purposes. Finally, out_z is just the SSM output multiplied by the activated gate tensor z, if you have one. This is essentially:

$$out_z = out \cdot \sigma(z)$$

where $\sigma$ is a softplus activation. This is again just a way to speed things up by fusing the gated multiplication into the SSM operation. Hope this helps, and feel free to ping me if you have other questions about it.
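
If it helps to see the shapes, here is a purely illustrative, sequential PyTorch sketch of what the three returned values correspond to conceptually. The real kernel fuses and parallelizes all of this on the GPU; the shapes and the checkpoint interval here are assumptions, and the gate is written as SiLU, following the SiLU/softplus clarification further down in this thread:

```python
import torch
import torch.nn.functional as F

def toy_selective_scan(u, delta, A, B, C, D, z, chunk=256):
    """Illustrative sequential reference for (out, scan_intermediates, out_z).

    Assumed shapes: u, delta, z: (batch, dim, seqlen); A: (dim, dstate);
    B, C: (batch, dstate, seqlen); D: (dim,). Not the fused CUDA kernel.
    """
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = u.new_zeros(batch, dim, dstate)              # hidden state
    ys, checkpoints = [], []
    for t in range(seqlen):
        dt = delta[:, :, t].unsqueeze(-1)            # (batch, dim, 1)
        dA = torch.exp(dt * A)                       # discretized A
        dB = dt * B[:, :, t].unsqueeze(1)            # discretized B
        x = dA * x + dB * u[:, :, t].unsqueeze(-1)   # state recurrence
        y = (x * C[:, :, t].unsqueeze(1)).sum(-1) + D * u[:, :, t]
        ys.append(y)
        if (t + 1) % chunk == 0:                     # save the state every `chunk` tokens
            checkpoints.append(x.clone())            # reused by the backward pass
    out = torch.stack(ys, dim=-1)                    # "out" = SSM(inputs)
    scan_intermediates = torch.stack(checkpoints, dim=-1) if checkpoints else None
    out_z = out * F.silu(z)                          # "out_z": gating fused into the op
    return out, scan_intermediates, out_z
```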

Hprairie avatar Nov 05 '24 01:11 Hprairie

@Hprairie I am trying to export/compile the selective_scan at https://github.com/pytorch/pytorch/issues/130150#issuecomment-2543153878

Do you have an exportable/compilable version?

bhack avatar Dec 16 '24 00:12 bhack

I don't have one right now (I might in a few weeks, along with a CUDA implementation for Mamba2). I know the lingua repo from FAIR has one that I have tested and that works. It should be fairly simple to wrap it, though.

Hprairie avatar Dec 16 '24 02:12 Hprairie

The lingua repo doesn't have a selective_scan custom_op: https://github.com/facebookresearch/lingua/tree/main/apps%2Fmamba%2Fcomponent

bhack avatar Dec 16 '24 09:12 bhack

Ahh, you're right, they only have Mamba2. It should be fairly easy to set it up. I looked at your error in pytorch#130150, and your problem is that you are trying to cast a list with a tensor in it into a tensor. You should just need to unpack it correctly.
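
A toy illustration of that kind of mistake (the shapes are made up and unrelated to the actual op):

```python
import torch

result = [torch.randn(2, 3)]   # e.g. an op returning a list that contains one tensor

# torch.tensor(result)         # fails: torch.tensor() cannot build a new tensor
#                              # from a list holding a multi-element tensor

(out,) = result                # unpack the single tensor from the list instead
print(out.shape)               # torch.Size([2, 3])
```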

Hprairie avatar Dec 16 '24 18:12 Hprairie

you are trying to cast a list with a tensor in it into a tensor. You should just need to unpack it correctly.

Yes, I found that one, but it seems I have a chain of issues in correctly wrapping your selective_scan in a regular custom_op.

Can we just hot-swap selective_scan with mamba_chunk_scan_combined? I want to try to export this NeurIPS 2024 model, and this selective_scan call is an issue without a custom_op: https://github.com/MCG-NJU/VFIMamba/blob/main/model/feature_extractor.py#L225

bhack avatar Dec 16 '24 19:12 bhack

I would just write a new custom_op around this function, creating a function for both the forward and the backward. Remember to register fake-tensor examples and then a context function; after that it should be easy to do. If you are not in a super rush, I will probably have a good implementation in a few weeks.
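
A rough sketch of such a wrapper using the torch.library.custom_op API (available in recent PyTorch releases). The namespace, names, and the restriction to z=None are illustrative assumptions, and the backward/autograd registration is omitted here:

```python
import torch
from torch.library import custom_op

# selective_scan_fn is the Python-level wrapper shipped with mamba-ssm
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn


@custom_op("mylib::selective_scan", mutates_args=())
def selective_scan(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                   B: torch.Tensor, C: torch.Tensor, D: torch.Tensor,
                   delta_bias: torch.Tensor, delta_softplus: bool) -> torch.Tensor:
    # Opaque to the compiler/exporter: it only sees this op, not the CUDA extension.
    return selective_scan_fn(u, delta, A, B, C, D, z=None,
                             delta_bias=delta_bias, delta_softplus=delta_softplus)


@selective_scan.register_fake
def _(u, delta, A, B, C, D, delta_bias, delta_softplus):
    # Fake (meta) implementation: only shapes/dtypes matter for tracing and export.
    return torch.empty_like(u)

# For training, a backward would also need to be registered (e.g. via
# register_autograd with a setup_context function that saves whatever
# selective_scan_cuda.bwd needs).
```

Calling selective_scan(...) then behaves like a regular PyTorch op from the point of view of torch.export / torch.compile.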

Hprairie avatar Dec 16 '24 19:12 Hprairie

I have at least a running draft. Let me know if you want to review it with a PR in this repo.

bhack avatar Dec 17 '24 04:12 bhack

The out is the selective scan output (i.e., typical Mamba if we are not fusing in the skip connection). Let's say that out has the following form:

$$out = SSM(inputs)$$

Then scan_intermediates are the hidden-state values saved partway through the scan at regular intervals. These are checkpointed for the backward pass so that it doesn't take too long for really long sequences. For example, for a sequence of length 2048 we can save the hidden state every 256 tokens to get 8 intermediate states, which are then reused in the backward pass for optimization purposes. Finally, out_z is just the SSM output multiplied by the activated gate tensor z, if you have one. This is essentially:

$$out_z = out \cdot \sigma(z)$$

where $\sigma$ is a softplus activation. This is again just a way to speed things up by fusing the gated multiplication into the SSM operation. Hope this helps, and feel free to ping me if you have other questions about it.

In the Mamba and Vision Mamba papers, the activation function is stated as SiLU, but in the code, the authors used Softplus. Please clarify if my understanding is incorrect.

poult-lab avatar Mar 13 '25 14:03 poult-lab

SiLU is used so that the output to the mamba layer is gated. Softplus is used for calculating delta.
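
In other words, conceptually (a small sketch with made-up shapes, not the fused kernel):

```python
import torch
import torch.nn.functional as F

# Stand-in tensors; the shapes are arbitrary:
delta_raw = torch.randn(2, 16, 128)         # pre-activation step sizes
delta_bias = torch.zeros(16, 1)
delta = F.softplus(delta_raw + delta_bias)  # softplus keeps the step size delta positive

out = torch.randn(2, 16, 128)               # SSM output (stand-in)
z = torch.randn(2, 16, 128)                 # gate branch
out_z = out * F.silu(z)                     # SiLU gates the SSM output
```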

Hprairie avatar Mar 13 '25 14:03 Hprairie

SiLU is used so that the output to the mamba layer is gated. Softplus is used for calculating delta.

Thank you so much for your prompt reply! Your response was very helpful. As you mentioned, SiLU is used as the activation function for the gating mechanism.

In the following function call:

out, scan_intermediates, out_z = selective_scan_cuda.fwd(
    conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)

out_z is the output of Mamba after being processed by the gating mechanism. Could you please tell me where SiLU is applied within the function selective_scan_cuda.fwd()?

poult-lab avatar Mar 14 '25 01:03 poult-lab

Look at the CUDA kernel; it's at the very end, right before the write-out :)

Hprairie avatar Mar 22 '25 03:03 Hprairie

SiLU is used so that the output to the mamba layer is gated. Softplus is used for calculating delta.

Thank you for your reply. I have another question: what is the difference between `y = selective_scan_fn(x, dt, A, B, C, self.D.float(), z=None, delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=None)` and `out, scan_intermediates, out_z = selective_scan_cuda.fwd(conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus)`? It seems they both serve as out = SSM(inputs).

poult-lab avatar Mar 25 '25 08:03 poult-lab

How can I turn the .ckpt checkpoint into a .pt file so I can use it with LibTorch, given that selective_scan_cuda is a CUDA C++ extension, which cannot be recognized by torch.jit.script?

cebain avatar Jul 24 '25 09:07 cebain