
Error with Mamba2

Adele0108 opened this issue on Jun 20 '24

Hi, I just tried the test code for Mamba-2:

```python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 1024
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,   # Model dimension d_model
    d_state=64,    # SSM state expansion factor, typically 64 or 128
    d_conv=4,      # Local convolution width
    expand=2,      # Block expansion factor
    headdim=128,
).to("cuda")
y = model(x)
assert y.shape == x.shape
print("Mamba2 model parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
print('x.shape:', x.shape, 'y.shape:', y.shape)
```

But I get the following error:

` File "/opt/anaconda3/envs/medfusion-2d/lib/python3.8/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 761, in forward causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported: 1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor Invoked with: tensor([[[ 0.6263, -0.1259, 0.6615, ..., 0.1121, 0.1023, -0.2840], [-0.5732, 1.5656, 0.5829, ..., 0.6564, 0.7546, 0.1331], [ 0.4265, -0.1785, 0.1311, ..., 0.6014, -1.0048, 0.0453], ..., [ 0.1693, -0.7641, -0.0408, ..., -0.3669, -0.2489, -0.2052], [ 0.8796, -0.5051, 0.3856, ..., 0.6248, 0.2461, -0.6594], [-0.6611, 0.2886, 0.4760, ..., -0.0319, 0.6962, -1.1070]],

    [[ 0.3243,  0.7392, -0.6660,  ..., -0.2669, -0.3460,  0.1921],
     [-0.1172,  0.2228, -0.1020,  ...,  1.1721,  2.1293,  0.4847],
     [ 0.0962,  0.2899, -0.6043,  ..., -0.6814,  0.4837,  0.0075],
     ...,
     [ 0.1357, -1.0081,  0.3166,  ..., -0.4532,  0.9043, -0.1286],
     [ 0.6356,  0.1391, -0.3242,  ...,  0.3308,  0.3722, -0.5956],
     [ 0.7242, -0.3001,  0.8165,  ...,  0.5277,  1.1039, -0.9327]]],
   device='cuda:0', requires_grad=True), tensor([[-0.1640,  0.4310, -0.2341,  0.2770],
    [ 0.1296, -0.1512,  0.0115,  0.1537],
    [-0.0655,  0.3352,  0.2952, -0.3224],
    ...,
    [-0.2745,  0.0135,  0.3997, -0.2371],
    [ 0.4181, -0.0019,  0.1142,  0.1713],
    [-0.3888,  0.3710,  0.4792,  0.2264]], device='cuda:0',
   grad_fn=<ViewBackward0>), Parameter containing:

tensor([-0.3444, -0.2064, -0.3750, ..., 0.2153, -0.1905, -0.0108], device='cuda:0', requires_grad=True), None, None, None, True `

Adele0108, Jun 20 '24
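Side note for anyone debugging the same thing: the `Invoked with:` dump above shows seven arguments being passed, while the installed binding only advertises a five-argument signature, i.e. mamba_ssm is calling a newer `causal_conv1d_fwd` than the causal-conv1d wheel in this environment provides. A small, hedged way to see which signature your own wheel exposes (pybind11 puts it in the function docstring):

```python
# Hedged sketch: print the signature(s) advertised by the compiled extension.
# causal_conv1d_cuda is the pybind11 module referenced in the traceback above.
import causal_conv1d_cuda

print(causal_conv1d_cuda.causal_conv1d_fwd.__doc__)
```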

Please update causal_conv1d.

tridao, Jun 21 '24
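In case it saves someone a step: upgrading is just `pip install --upgrade causal-conv1d` (keeping mamba-ssm current at the same time). The exact minimum version mamba_ssm expects is not stated in this thread, so check the requirements pinned by the mamba repo for the release you use. A minimal sketch to confirm what the environment ends up with:

```python
# Hedged sketch: report the installed versions after upgrading from the shell
# (e.g. `pip install --upgrade causal-conv1d`).
from importlib.metadata import version, PackageNotFoundError

for dist in ("causal-conv1d", "mamba-ssm"):
    try:
        print(dist, version(dist))
    except PackageNotFoundError:
        print(dist, "is not installed")
```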

Thanks for your prompt answer. After updating, I get a new error:

```
File "/opt/anaconda3/envs/medfusion-2d/lib/python3.8/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 761, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8
```

Adele0108, Jun 21 '24
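For later readers, my reading (an assumption, not something stated in this thread) is that the check fires because the fused Mamba2 path slices `xBC` out of the packed `in_proj` output, so after `rearrange(xBC, "b s d -> b d s")` the stride along the last axis equals the full packed projection width, and the kernel wants that width to be a multiple of 8. Under that assumption, a quick pre-flight check of a config looks like this (the packed-width formula and the default `d_state`/`headdim` values mirror what I believe `mamba2.py` uses; verify against your installed version):

```python
# Hedged sketch: estimate the packed in_proj width that Mamba2's fused path
# slices xBC from, and check the multiple-of-8 stride requirement against it.
# The formula 2*d_inner + 2*ngroups*d_state + nheads and the default values for
# d_state/headdim are assumptions based on reading mamba_ssm/modules/mamba2.py.

def packed_width(d_model, expand=2, d_state=128, headdim=64, ngroups=1):
    d_inner = expand * d_model
    nheads = d_inner // headdim
    return 2 * d_inner + 2 * ngroups * d_state + nheads

configs = [
    dict(d_model=1024, d_state=64, headdim=128),  # the config at the top of this issue
    dict(d_model=64),                             # a small d_model with the assumed defaults
]
for cfg in configs:
    w = packed_width(**cfg)
    print(cfg, "-> packed width", w, "| multiple of 8:", w % 8 == 0)
```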

Have you solved this problem? I also hit it: `causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8`

Peilin-FF, Jul 24 '24

Have you solved this problem? I also hit it: `causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8`

buAUFDkru, Nov 04 '24

> Please update causal_conv1d.

```
Traceback (most recent call last):
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/e/project2/accelerated_features-main/modules/training/train.py", line 330, in <module>
    trainer.train()
  File "/mnt/e/project2/accelerated_features-main/modules/training/train.py", line 236, in train
    feats1, kpts1, hmap1 = self.net(p1)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/project2/accelerated_features-main/modules/model5.py", line 154, in forward
    x3 = self.block3(x2)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/project2/accelerated_features-main/modules/LightManbaXfeatNet.py", line 136, in forward
    x = self.conv1(x)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/project2/accelerated_features-main/modules/LightManbaXfeatNet.py", line 35, in forward
    x_mamba = self.mamba(x_norm) + self.skip_scale * x_flat
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py", line 183, in forward
    out = mamba_split_conv1d_scan_combined(
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 930, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 779, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8
```

What is a stride, and what can I do about this?

Shape of x: torch.Size([3, 7600, 64]); strides of x: (486400, 64, 1)

buAUFDkru, Nov 04 '24
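To the "what is a stride?" question: a tensor's stride along a dimension is how many elements you jump in memory to move one step along that dimension, and views such as slicing and `rearrange` change strides without copying data. The sketch below is purely illustrative: the sizes are hypothetical (a packed width of 514 is what I would expect for `d_model=64` with the defaults I assume, not something read from your traceback), but it reproduces the shape of the problem, where the sliced-and-transposed `xBC` ends up with a last-axis stride that is not a multiple of 8:

```python
import torch
from einops import rearrange

# Hypothetical sizes for illustration only (assumed, not taken from the traceback):
# a packed in_proj output of width 514 = 128 (z) + 384 (xBC) + 2 (dt).
batch, seqlen = 3, 7600
zxbcdt = torch.randn(batch, seqlen, 514)

z, xBC, dt = torch.split(zxbcdt, [128, 384, 2], dim=-1)  # slices share storage with zxbcdt
x = rearrange(xBC, "b s d -> b d s")                     # transpose view, still no copy

print("shape:", x.shape, "strides:", x.stride())
print("stride(0) % 8 =", x.stride(0) % 8, "| stride(2) % 8 =", x.stride(2) % 8)
# stride(2) == 514, which is not a multiple of 8 -- the same condition the
# RuntimeError above complains about. Calling .contiguous() on the slice would
# satisfy the check, but since the slicing happens inside mamba_ssm, a practical
# workaround is to pick d_model / headdim / d_state so the packed width is a
# multiple of 8 (see the check sketched a few comments up).
```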