
Errors when running the test suite.

Open CompRhys opened this issue 1 year ago • 1 comment

________________________________________________________________________________________________ test_mamba_inner_fn[False-True-128-itype0-wtype0] ________________________________________________________________________________________________

is_variable_B = False, is_variable_C = True, seqlen = 128, itype = torch.float32, wtype = torch.float32

    @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
    # @pytest.mark.parametrize('wtype', [torch.complex64])
    # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize('itype', [torch.float32])
    # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
    @pytest.mark.parametrize('seqlen', [128])
    @pytest.mark.parametrize("is_variable_C", [False, True])
    # @pytest.mark.parametrize("is_variable_C", [False])
    @pytest.mark.parametrize("is_variable_B", [False, True])
    # @pytest.mark.parametrize("is_variable_B", [True])
    def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
        device = 'cuda'
        rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
        if itype == torch.bfloat16:
            rtol, atol = 3e-2, 5e-2
        rtolw, atolw = (1e-3, 1e-3)
        # If we have z, the errors on the weights seem higher
        rtolw = max(rtolw, rtol)
        atolw = max(atolw, atol)
        # set seed
        torch.random.manual_seed(0)
        batch_size = 2
        dim = 768
        dstate = 8
        dt_rank = 48
        is_complex = wtype == torch.complex64
        xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
        conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
        conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
        x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
                                    * (1 if not is_complex else 2),
                                    dim, device=device, dtype=itype, requires_grad=True)
        delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
        out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
        out_proj_bias = None
        A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
        B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
             if not is_variable_B else None)
        C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
             if not is_variable_C else None)
        D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
        delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
        B_proj_bias = None
        C_proj_bias = None
        xz_ref = xz.detach().clone().requires_grad_()
        conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
        conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
        x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
        delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
        out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
        out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
                             if out_proj_bias is not None else None)
        A_ref = A.detach().clone().requires_grad_()
        B_ref = B.detach().clone().requires_grad_() if B is not None else None
        C_ref = C.detach().clone().requires_grad_() if C is not None else None
        D_ref = D.detach().clone().requires_grad_()
        delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
        out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                             out_proj_weight, out_proj_bias,
                             A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
>       out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
                                  delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
                                  A_ref, B_ref, C_ref, D_ref,
                                  delta_bias=delta_bias_ref, delta_softplus=True)

tests/ops/test_selective_scan.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mamba_ssm/ops/selective_scan_interface.py:321: in mamba_inner_ref
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
/opt/miniconda/lib/python3.10/site-packages/causal_conv1d/causal_conv1d_interface.py:49: in causal_conv1d_fn
    return CausalConv1dFn.apply(x, weight, bias, seq_idx, activation)
/opt/miniconda/lib/python3.10/site-packages/torch/autograd/function.py:553: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ctx = <torch.autograd.function.CausalConv1dFnBackward object at 0x7fa222223840>
x = tensor([[[-0.9247, -0.4253, -2.6438,  ..., -0.2128, -0.3315, -0.2023],
         [-1.1451, -0.5715, -0.6510,  ...,  1.3...      [-0.2289, -0.1726,  1.8851,  ..., -0.1589,  0.6690,  1.3431]]],
       device='cuda:0', grad_fn=<SplitBackward0>)
weight = tensor([[ 0.1808, -0.5523,  0.9238],
        [-0.7350,  1.3800,  0.8676],
        [ 0.1297, -0.9406,  0.8109],
       ...],
        [ 0.8140,  1.0932, -0.2314],
        [-0.2205, -0.9232, -1.6818]], device='cuda:0', grad_fn=<ViewBackward0>)
bias = tensor([ 2.5441e+00, -7.1635e-01, -4.9337e-01,  1.2671e-01,  1.0136e-01,
        -4.0353e-01,  9.0226e-01,  8.0993e-01...,  1.3356e+00, -1.1588e+00,
        -2.5133e-01, -1.3636e-01,  2.8971e-01], device='cuda:0',
       requires_grad=True)
seq_idx = 'silu', activation = None

    @staticmethod
    def forward(ctx, x, weight, bias=None, seq_idx=None, activation=None):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
>       seq_idx = seq_idx.contiguous() if seq_idx is not None else None
E       AttributeError: 'str' object has no attribute 'contiguous'

/opt/miniconda/lib/python3.10/site-packages/causal_conv1d/causal_conv1d_interface.py:19: AttributeError

After installing mamba-ssm, I ran the test suite to check the install and got the preceding error. In total there were 8 failures in the test suite, all with the same .contiguous AttributeError.
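The locals in the traceback show what went wrong: mamba_ssm calls causal_conv1d_fn(x, weight, bias, "silu") positionally, but the forward above now takes seq_idx before activation, so the string lands in the wrong slot (note seq_idx = 'silu', activation = None). A minimal sketch of the mismatch, using hypothetical stand-in functions rather than the real causal_conv1d code:

# Hypothetical stand-ins illustrating the signature mismatch; not the
# real library code.

def causal_conv1d_fn_old(x, weight, bias=None, activation=None):
    # Presumed old layout: the fourth positional argument was `activation`.
    return activation

def causal_conv1d_fn_new(x, weight, bias=None, seq_idx=None, activation=None):
    # New layout (per the traceback): `seq_idx` now occupies the fourth
    # positional slot, pushing `activation` to fifth.
    # This mirrors the failing line in causal_conv1d_interface.py:
    seq_idx = seq_idx.contiguous() if seq_idx is not None else None
    return activation

# mamba_ssm still uses the old positional layout, so "silu" binds to
# seq_idx and the .contiguous() call raises on a str:
try:
    causal_conv1d_fn_new("x", "w", "b", "silu")
except AttributeError as e:
    print(e)  # 'str' object has no attribute 'contiguous'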

CompRhys · Feb 22 '24 19:02

I think this is caused by the causal_conv1d interface change in this commit.

There is a PR trying to fix the bug.

But for now, I think installing causal-conv1d<=1.0.2 might fix this?
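If pinning (e.g. pip install "causal-conv1d<=1.0.2") is not an option, passing the activation by keyword at the mamba_ssm call site should also sidestep the new seq_idx slot. A sketch against the new signature shown in the traceback; the linked PR may do something equivalent:

from einops import rearrange
from causal_conv1d import causal_conv1d_fn

# Instead of the positional call that now misbinds "silu" to seq_idx:
#     causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"),
#                      conv1d_bias, "silu")
# pass the activation by keyword so it cannot land in seq_idx:
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"),
                     conv1d_bias, activation="silu")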

BlenderWang9487 · Feb 23 '24 02:02