Error with Mamba2
Hi, I just tried the Mamba-2 test code, like this:

```python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 1024
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=128,
).to("cuda")
y = model(x)
assert y.shape == x.shape
print("Mamba2 model parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
print("x.shape:", x.shape, "y.shape:", y.shape)
```
But I get the following error:
```
File "/opt/anaconda3/envs/medfusion-2d/lib/python3.8/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 761, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

Invoked with: tensor([[[ 0.6263, -0.1259,  0.6615,  ...,  0.1121,  0.1023, -0.2840],
                       ...,
                      [ 0.7242, -0.3001,  0.8165,  ...,  0.5277,  1.1039, -0.9327]]],
                     device='cuda:0', requires_grad=True),
              tensor([[-0.1640,  0.4310, -0.2341,  0.2770],
                      ...,
                      [-0.3888,  0.3710,  0.4792,  0.2264]], device='cuda:0', grad_fn=<ViewBackward0>),
              Parameter containing:
              tensor([-0.3444, -0.2064, -0.3750,  ...,  0.2153, -0.1905, -0.0108], device='cuda:0', requires_grad=True),
              None, None, None, True
```