mamba
RuntimeError: Selective_scan only supports state dimension <= 256
When using mamba, I encountered the following RuntimeError:
File "/data2/user/workspace/D-Mamba/d_mamba/models/mamba.py", line 171, in forward
h = self.mamba(h)
File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 146, in forward
out = mamba_inner_fn(
File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 306, in mamba_inner_fn
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
return fwd(*args, **kwargs)
File "/home/user/miniconda3/envs/py39/lib/python3.9/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 217, in forward
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
RuntimeError: selective_scan only supports state dimension <= 256
Here is the code I use:
import torch
from mamba_ssm.modules.mamba_simple import Block, Mamba
batch, length, dim = 16, 64, 512
x = torch.randn(batch, length, dim).to('cuda')
model = Mamba(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model; = `dim` in this case
d_state=dim, # SSM state expansion factor; should be the same to `dim`
d_conv=4, # Local convolution width
expand=16, # Block expansion factor; should be the same to `batch` size
).to('cuda')
y = model(x)
print(y.shape)
Why is there a restriction on this dimension? It does not seem suitable for most use cases.
I don't think the comments match the parameters they purport to describe in your Mamba instantiation. E.g., d_state is the number of states in the linear dynamic state-space models, and should not in general "be the same as `dim`". Nor do I see why expand should in general equal the batch size.
The values you assign are radically different from those of the trained models the authors published: you have d_state = dim = 512, but d_state is 16 in the published models, and you set expand=16, but it is 2 in the published models. I doubt the values you are trying to use would make sense in most circumstances, at least without more explanation.
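For scale, the approximation quoted in the snippet's own comment (roughly 3 * expand * d_model^2 parameters per block) shows how much larger expand=16 makes a single block than the published expand=2. A quick sketch (the factor of 3 is just the approximation from that comment, not an exact count):

```python
# Rough per-block parameter count, using the approximation quoted in the
# snippet's own comment: ~3 * expand * d_model^2 parameters per Mamba block.
def approx_mamba_params(d_model, expand):
    return 3 * expand * d_model ** 2

print(approx_mamba_params(512, 16))  # expand=16 as in the question: 12582912 (~12.6M)
print(approx_mamba_params(512, 2))   # expand=2 as published: 1572864 (~1.6M)
```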
Selective scan is designed to exploit the GPU memory hierarchy; if the SSM state grows too large, it no longer fits in fast on-chip memory.
Also, the dynamic models represented are many single-input single-output (SISO) dynamic systems, each corresponding to an uncoupled differential equation of order d_state=16. A 16th-order linear dynamic system is already fairly high order for a SISO system, especially considering there are thousands of them per block. Letting it grow more than an order of magnitude larger would be of dubious value, at least in the use cases I've seen for transformer-style sequence-to-sequence models.
To put a finer point on it, the authors ran ablation studies, and the paper shows the effect of state size in Table 10. There was not much improvement in perplexity in going from state size 8 to 16, so 16 is probably as big as makes sense for language models.
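Concretely, keeping the published defaults for d_state and expand avoids the error. A sketch of just the constructor call (it assumes the rest of the snippet above, with mamba_ssm installed and a CUDA device available):

```python
model = Mamba(
    d_model=dim,  # model dimension (512 in the snippet above)
    d_state=16,   # SSM state size per channel, as in the published models
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor, as in the published models
).to('cuda')
```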
Yup, d_state here means each dimension of d_model will be expanded by a factor of d_state. We used d_state=16 in our experiments.
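As a quick sanity check of the arithmetic (a sketch; d_inner = expand * d_model follows the expansion described above, and 256 is the cap from the error message):

```python
# Each of the d_inner = expand * d_model channels carries a d_state-dim
# SSM state, and the fused CUDA kernel requires d_state <= 256.
def state_config(d_model, d_state, expand):
    d_inner = expand * d_model
    return d_inner, d_inner * d_state, d_state <= 256

print(state_config(512, 512, 16))  # (8192, 4194304, False): d_state=512 is rejected
print(state_config(512, 16, 2))    # (1024, 16384, True): the published default fits
```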
Thanks @jyegerlehner for the detailed explanation.
@jyegerlehner @tridao Thank you for your explanation!