
feat: Initial state support for Mamba SSM (1)

Open mzusman opened this pull request 1 year ago • 4 comments

Add chunked prefill / initial-state capability to the Mamba SSM (Mamba 1). This is done by prepending the last forward pass's state to the fwd pass kernel and reading the data accordingly.

Latency is not affected (a benchmark script shows similar latencies between this PR and main, ~130 ms). Added tests that check correctness when running on chunks.

Limitations:

  • Applies only to the selective scan fwd pass (the bwd pass is not supported)

This PR enables efficient speculative decoding, prefix caching, and prefill chunking; see the usage sketch below.
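For illustration, here is a minimal sketch of what chunk-by-chunk prefill could look like from Python once an initial state can be fed in. The prev_state keyword is an assumed name for the argument this PR introduces (the exact Python-level API may differ); return_last_state already exists upstream:

    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

    batch, dim, dstate, seqlen, chunk = 2, 64, 16, 4096, 1024
    u = torch.randn(batch, dim, seqlen, device="cuda")
    delta = torch.rand(batch, dim, seqlen, device="cuda")
    A = -torch.rand(dim, dstate, device="cuda")
    B = torch.randn(batch, dstate, seqlen, device="cuda")
    C = torch.randn(batch, dstate, seqlen, device="cuda")

    state, outs = None, []
    for start in range(0, seqlen, chunk):
        sl = slice(start, start + chunk)
        out, state = selective_scan_fn(
            u[..., sl], delta[..., sl], A, B[..., sl], C[..., sl],
            delta_softplus=True,
            return_last_state=True,  # upstream flag: also return the final SSM state
            prev_state=state,        # assumed name for the new initial-state input
        )
        outs.append(out)
    # Concatenating the chunked outputs should match a single full-length scan.
    full = torch.cat(outs, dim=-1)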

FIX #233 #473 #258 #101

mzusman avatar Jul 24 '24 09:07 mzusman

@mzusman I've noticed you made changes to files in the csrc directory, but I'm having trouble getting these changes to take effect in my environment. Could you please tell me the exact instructions to rebuild and install the mamba_ssm package so the changes are applied? It seems I always get the original package when using pip install . Thank you!

daphneOdera-618 avatar Sep 02 '24 17:09 daphneOdera-618

@daphneOdera-618 Yeah, the default setup.py behaviour is to download the upstream wheel upon installing. To force a build from source, run MAMBA_FORCE_BUILD=TRUE pip install .

mzusman avatar Sep 03 '24 07:09 mzusman

Unfortunately, this PR changes the API of selective_scan_cuda.fwd in an incompatible way. The same API is also invoked in MambaInnerFn.forward in addition to SelectiveScanFn.forward, leading to runtime errors in code that uses MambaInnerFn (e.g. the Mamba implementation found in the transformers library when running in vanilla training mode without cache_params).

I think MambaInnerFn.forward could be modified to use the new API version, but I don't know how to produce the prerequisite additional empty vector (x) from what is available in MambaInnerFn.forward.
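
Concretely, the call-site difference looks roughly like this (a sketch based on my reading of the upstream code and of this PR):

    # Upstream mamba_ssm: the kernel allocates and returns the scan intermediates x.
    out, x, *rest = selective_scan_cuda.fwd(
        u, delta, A, B, C, D, z, delta_bias, delta_softplus
    )

    # With this PR: the caller must allocate x and pass it in, so that it can
    # be pre-seeded with an initial state before the kernel runs.
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
        u, delta, A, B, C, D, z, delta_bias, delta_softplus, x
    )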

jploski avatar Jun 09 '25 14:06 jploski


Since conv1d_out in MambaInnerFn seems to play the same role as u in SelectiveScanFn, adding this hack in place of the original invocation of selective_scan_cuda.fwd seems to work:

        u = conv1d_out
        # One checkpoint per 2048-timestep chunk (ceil division).
        n_chunks = (u.shape[-1] + 2048 - 1) // 2048
        # Pre-allocate the scan-intermediates buffer that the new API expects
        # as an input instead of producing it itself.
        _x = torch.zeros(
            (u.shape[0], u.shape[1], n_chunks, int(A.shape[1] * 2)),
            device=u.device,
            dtype=torch.float32,
            requires_grad=u.requires_grad,
        )
        # Seed the even slots of the first chunk with 1.
        _x[:, :, 0, 0::2] = 1
        # An initial state, if available, would presumably go into the odd slots:
        # if prev_state is not None:
        #     _x[:, :, 0, 1::2].copy_(prev_state)
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            u, delta, A, B, C, D, z, delta_bias, delta_softplus, _x
        )
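
If I read the buffer layout correctly, the even-indexed entries along the last dimension hold the cumulative decay factors of the associative scan (hence seeding them with 1, the multiplicative identity), while the odd-indexed entries hold the per-chunk hidden state, which is why an initial state would be copied into the 1::2 slice of the first chunk. That is only my interpretation of the layout, though.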

jploski avatar Jun 09 '25 14:06 jploski