
RuntimeError: Selective_scan only supports state dimension <= 256


When using mamba, I found the following RuntimeError:

  File "/data2/user/workspace/D-Mamba/d_mamba/models/mamba.py", line 171, in forward
    h = self.mamba(h)
  File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 306, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 217, in forward
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
RuntimeError: selective_scan only supports state dimension <= 256

Here is the code I use:

import torch
from mamba_ssm.modules.mamba_simple import Block, Mamba

batch, length, dim = 16, 64, 512
x = torch.randn(batch, length, dim).to('cuda')
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model; = `dim` in this case
    d_state=dim,  # SSM state expansion factor; should be the same as `dim`
    d_conv=4,  # Local convolution width
    expand=16,  # Block expansion factor; should be the same as the `batch` size
).to('cuda')
y = model(x)
print(y.shape)

Why is there a restriction on the state dimension? It doesn't seem suitable for most use cases.

pprp · Jan 21 '24

I don't think the comments match the parameters they purport to describe in your Mamba instantiation. For example, d_state is the number of states in the linear dynamical state-space models, and should not in general "be the same as `dim`". Nor do I see why expand should in general be the same as the batch size.

The values you assign are radically different from those of the trained models the authors published. You have d_state = dim = 512, but d_state is 16 in the published models. You set expand=16, but it is 2 in the published models. I am doubtful that these values you are trying to use would make sense in most circumstances, at least without more explanation.
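
For comparison, here is a minimal sketch of the same call with the published hyperparameters (d_state=16 and expand=2; d_conv=4 kept from your snippet), assuming the same mamba_ssm API you are already importing:

import torch
from mamba_ssm.modules.mamba_simple import Mamba

batch, length, dim = 16, 64, 512
x = torch.randn(batch, length, dim).to('cuda')
model = Mamba(
    d_model=dim,  # model (channel) dimension; independent of batch and length
    d_state=16,   # SSM state size per channel, the published default
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor, the published default
).to('cuda')
y = model(x)      # output shape (16, 64, 512), same as the input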

Selective scan is designed to exploit the GPU's memory and cache hierarchy; if the SSM state grows too large, it no longer fits in fast on-chip memory.
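
As a rough back-of-the-envelope illustration (my own arithmetic, assuming the recurrent state per sequence scales as expand * d_model * d_state; the exact on-chip layout is kernel specific):

# hypothetical helper, just counting recurrent-state elements per sequence
def state_elements(d_model, d_state, expand):
    return expand * d_model * d_state

issue_cfg = state_elements(d_model=512, d_state=512, expand=16)  # settings from the snippet above
default_cfg = state_elements(d_model=512, d_state=16, expand=2)  # published defaults

print(issue_cfg, default_cfg, issue_cfg // default_cfg)
# 4194304 16384 256  -> roughly 256x more state to keep close to the compute units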

Also, the dynamic models represented are many single-input single-output (SISO) dynamic systems, corresponding to a system of uncoupled differential equations of order d_state=16. A 16th-order linear dynamic system is already a pretty high order (for a SISO system), especially considering there are thousands of them per block. Letting it get more than an order of magnitude bigger would be of dubious value, at least in the use cases I've seen for transformer-type sequence-to-sequence models.

jyegerlehner · Jan 21 '24

To put a finer point on it, the authors did ablation studies, and the paper shows the effect of state size in Table 10. There was not much improvement in perplexity going from state size 8 to 16, so 16 is probably as big as makes sense in the case of language models.

jyegerlehner · Jan 21 '24

Yup, d_state here means each dimension of d_model will be expanded by a factor of d_state. We used d_state=16 in our experiments. Thanks @jyegerlehner for the detailed explanation.
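
A shape-level sketch of that expansion (my own bookkeeping, not the kernel internals): each of the expand * d_model inner channels carries its own length-d_state SSM state, so the recurrent state for one sequence has shape (expand * d_model, d_state).

import torch

d_model, d_state, expand = 512, 16, 2
d_inner = expand * d_model         # channels after the block's input expansion

h = torch.zeros(d_inner, d_state)  # conceptual per-sequence recurrent state
print(h.shape)                     # torch.Size([1024, 16])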

tridao · Jan 21 '24

@jyegerlehner @tridao Thank you for your explanation!

pprp · Feb 05 '24