
Test failure when dim is large

Open · bilgeacun opened this issue 11 months ago · 1 comment

Hi,

I'm running the mamba test_selective_scan.py test with an increased model dimension, and the tests start to fail. Here is how I increased the dimension:

diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py
index 8a834b3..2cf04e0 100644
--- a/tests/ops/test_selective_scan.py
+++ b/tests/ops/test_selective_scan.py
@@ -50,7 +50,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
-    dim = 4
+    dim = 4096
     dstate = 8
     is_complex = wtype == torch.complex64
     A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()

And here is the output of pytest test_selective_scan.py -s:

        print(f'Output max diff: {(out - out_ref).abs().max().item()}')
        print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
>       assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x7fe3c3414640>(tensor([[[ 3.9609e-03, -3.3309e+00, -5.4975e-01,  ..., -2.1817e+00,\n           2.3821e-01,  3.5165e-01],\n         [ 3....0,  ...,  3.1667e-01,\n          -4.9747e-01, -1.8112e-02]]], device='cuda:0',\n       grad_fn=<SelectiveScanFnBackward>), tensor([[[ 3.9609e-03, -3.3309e+00, -5.4975e-01,  ..., -2.1817e+00,\n           2.3821e-01,  3.5165e-01],\n         [ 3....43e-02, -1.7847e+00,  ...,  3.1667e-01,\n          -4.9747e-01, -1.8111e-02]]], device='cuda:0', grad_fn=<MulBackward0>), rtol=0.0006, atol=0.002)
E        +    where <built-in method allclose of type object at 0x7fe3c3414640> = torch.allclose

test_selective_scan.py:114: AssertionError
====================================================== short test summary info ======================================================
FAILED test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-256-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-1024-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-2048-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-4096-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-512-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-2048-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-4096-itype0-wtype0] - AssertionError: assert False
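A single max diff does not show whether a handful of elements or a large fraction of the output exceeds the tolerance. The sketch below is a hypothetical helper, allclose_report, which is not part of the mamba test suite. It re-implements in pure Python the elementwise condition that torch.allclose documents, |out - ref| <= atol + rtol * |ref| (NaN handling ignored), and uses the rtol and atol from the failing assertion as defaults.

```python
def allclose_report(out, ref, rtol=6e-4, atol=2e-3):
    """Summarize where `out` deviates from `ref` under allclose-style tolerances.

    Hypothetical diagnostic helper (not from mamba). An element passes when
    |out - ref| <= atol + rtol * |ref|, the elementwise condition torch.allclose
    documents; NaN/inf handling is deliberately ignored here.
    """
    if len(out) != len(ref):
        raise ValueError("out and ref must have the same length")
    n_fail = 0
    worst_idx, worst_ratio = None, 0.0
    for i, (o, r) in enumerate(zip(out, ref)):
        allowed = atol + rtol * abs(r)
        diff = abs(o - r)
        if diff > allowed:
            n_fail += 1
        # Ratio of actual error to allowed error; > 1.0 means this element fails.
        if allowed > 0:
            ratio = diff / allowed
        else:
            ratio = float("inf") if diff > 0 else 0.0
        if ratio > worst_ratio:
            worst_idx, worst_ratio = i, ratio
    return {
        "n_total": len(out),
        "n_fail": n_fail,
        "worst_idx": worst_idx,
        "worst_ratio": worst_ratio,
    }


if __name__ == "__main__":
    print(allclose_report([1.0, 2.0, 3.0], [1.0, 2.0, 3.01]))
```

To apply it to the tensors in the test, passing out.flatten().tolist() and out_ref.flatten().tolist() should work; for outputs with millions of elements, doing the same comparison with torch ops on the GPU would be much faster than a Python loop.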

It seems there is a correctness issue in the kernel implementation. Could you please take a look?

I'm running causal-conv1d v1.1.0 and mamba-ssm v1.1.1.
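Exact package versions help when reproducing this kind of numerical failure. A minimal sketch for collecting them uses the standard-library importlib.metadata. The distribution names below are assumptions based on the pip package names, and installed_versions is a hypothetical helper, not part of mamba.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "causal-conv1d", "mamba-ssm")):
    """Return {distribution name: version string, or None if not installed}.

    Hypothetical helper; the default names assume the pip distribution names.
    """
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```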

Thanks!

bilgeacun, Mar 19 '24 19:03