mamba
Errors when running the test suite.
________________________________________________________________________________________________ test_mamba_inner_fn[False-True-128-itype0-wtype0] ________________________________________________________________________________________________
is_variable_B = False, is_variable_C = True, seqlen = 128, itype = torch.float32, wtype = torch.float32
    @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
    # @pytest.mark.parametrize('wtype', [torch.complex64])
    # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize('itype', [torch.float32])
    # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
    @pytest.mark.parametrize('seqlen', [128])
    @pytest.mark.parametrize("is_variable_C", [False, True])
    # @pytest.mark.parametrize("is_variable_C", [False])
    @pytest.mark.parametrize("is_variable_B", [False, True])
    # @pytest.mark.parametrize("is_variable_B", [True])
    def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
        device = 'cuda'
        rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
        if itype == torch.bfloat16:
            rtol, atol = 3e-2, 5e-2
        rtolw, atolw = (1e-3, 1e-3)
        # If we have z, the errors on the weights seem higher
        rtolw = max(rtolw, rtol)
        atolw = max(atolw, atol)
        # set seed
        torch.random.manual_seed(0)
        batch_size = 2
        dim = 768
        dstate = 8
        dt_rank = 48
        is_complex = wtype == torch.complex64
        xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
        conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
        conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
        x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
                                    * (1 if not is_complex else 2),
                                    dim, device=device, dtype=itype, requires_grad=True)
        delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
        out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
        out_proj_bias = None
        A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
        B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
             if not is_variable_B else None)
        C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
             if not is_variable_C else None)
        D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
        delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
        B_proj_bias = None
        C_proj_bias = None
        xz_ref = xz.detach().clone().requires_grad_()
        conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
        conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
        x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
        delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
        out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
        out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
                             if out_proj_bias is not None else None)
        A_ref = A.detach().clone().requires_grad_()
        B_ref = B.detach().clone().requires_grad_() if B is not None else None
        C_ref = C.detach().clone().requires_grad_() if C is not None else None
        D_ref = D.detach().clone().requires_grad_()
        delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
        out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                             out_proj_weight, out_proj_bias,
                             A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
>       out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
                                  delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
                                  A_ref, B_ref, C_ref, D_ref,
                                  delta_bias=delta_bias_ref, delta_softplus=True)
tests/ops/test_selective_scan.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mamba_ssm/ops/selective_scan_interface.py:321: in mamba_inner_ref
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
/opt/miniconda/lib/python3.10/site-packages/causal_conv1d/causal_conv1d_interface.py:49: in causal_conv1d_fn
    return CausalConv1dFn.apply(x, weight, bias, seq_idx, activation)
/opt/miniconda/lib/python3.10/site-packages/torch/autograd/function.py:553: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = <torch.autograd.function.CausalConv1dFnBackward object at 0x7fa222223840>
x = tensor([[[-0.9247, -0.4253, -2.6438, ..., -0.2128, -0.3315, -0.2023],
[-1.1451, -0.5715, -0.6510, ..., 1.3... [-0.2289, -0.1726, 1.8851, ..., -0.1589, 0.6690, 1.3431]]],
device='cuda:0', grad_fn=<SplitBackward0>)
weight = tensor([[ 0.1808, -0.5523, 0.9238],
[-0.7350, 1.3800, 0.8676],
[ 0.1297, -0.9406, 0.8109],
...],
[ 0.8140, 1.0932, -0.2314],
[-0.2205, -0.9232, -1.6818]], device='cuda:0', grad_fn=<ViewBackward0>)
bias = tensor([ 2.5441e+00, -7.1635e-01, -4.9337e-01, 1.2671e-01, 1.0136e-01,
-4.0353e-01, 9.0226e-01, 8.0993e-01..., 1.3356e+00, -1.1588e+00,
-2.5133e-01, -1.3636e-01, 2.8971e-01], device='cuda:0',
requires_grad=True)
seq_idx = 'silu', activation = None
    @staticmethod
    def forward(ctx, x, weight, bias=None, seq_idx=None, activation=None):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
>       seq_idx = seq_idx.contiguous() if seq_idx is not None else None
E       AttributeError: 'str' object has no attribute 'contiguous'
/opt/miniconda/lib/python3.10/site-packages/causal_conv1d/causal_conv1d_interface.py:19: AttributeError
After installing mamba-ssm, I ran the test suite to check the install and got the preceding error. In total there were 8 failures in the test suite, all with the same .contiguous attribute error.
I think this is caused by the causal_conv1d interface change in this commit: the newer causal_conv1d_fn takes a seq_idx argument before activation, so the "silu" that mamba_inner_ref passes positionally now lands in seq_idx, as illustrated just below.
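Here is a minimal, standalone illustration of the argument mismatch. These are hypothetical stand-in functions, not the real causal_conv1d code; the "new" parameter order is taken from the forward signature visible in the traceback, and the "old" signature is my assumption about what mamba_inner_ref was written against.

# Hypothetical stand-ins; only the parameter order mirrors the traceback.
def causal_conv1d_fn_old(x, weight, bias=None, activation=None):
    # old interface (assumed): the 4th positional argument is the activation
    return activation

def causal_conv1d_fn_new(x, weight, bias=None, seq_idx=None, activation=None):
    # new interface: seq_idx now sits where activation used to be
    seq_idx = seq_idx.contiguous() if seq_idx is not None else None
    return activation

print(causal_conv1d_fn_old(None, None, None, "silu"))  # "silu" -> activation, works
print(causal_conv1d_fn_new(None, None, None, "silu"))  # "silu" -> seq_idx, raises AttributeError: 'str' object has no attribute 'contiguous'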
There is a PR trying to fix the bug, but for now I think installing causal-conv1d<=1.0.2 might fix this?
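If pinning is not possible, a local patch along these lines might also work. This is only a sketch of the single call shown in the traceback (mamba_ssm/ops/selective_scan_interface.py:321), and it assumes the newer causal_conv1d_fn accepts activation as a keyword argument, which the forward signature above suggests it does:

# Before (in mamba_inner_ref): activation passed positionally, now misread as seq_idx
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")

# After: activation passed by keyword, leaving the new seq_idx parameter as None
x = causal_conv1d_fn(
    x,
    rearrange(conv1d_weight, "d 1 w -> d w"),
    conv1d_bias,
    activation="silu",
)

Either way, the proper fix presumably belongs upstream in that PR; the keyword-argument change is just a stopgap for a local checkout.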