mamba
Test failure when dim is large
Hi,
I'm running the mamba `test_selective_scan.py` test with a larger model dimension, and the tests start to fail. Here is how I increased the dimension:
```diff
diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py
index 8a834b3..2cf04e0 100644
--- a/tests/ops/test_selective_scan.py
+++ b/tests/ops/test_selective_scan.py
@@ -50,7 +50,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
-    dim = 4
+    dim = 4096
     dstate = 8
     is_complex = wtype == torch.complex64
     A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
```
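For scale, here is a quick count of how many elements the final `torch.allclose` check compares. This assumes the output has shape `(batch_size, dim, seqlen)`, like the test's input `u`, and uses seqlen=4096, one of the failing parametrizations. The helper name is hypothetical and not part of mamba.

```python
# Hypothetical helper (not part of mamba): number of scalars that the final
# torch.allclose(out, out_ref, ...) compares, ASSUMING out has shape
# (batch_size, dim, seqlen) like the test's input u.
def num_output_elements(batch_size: int, dim: int, seqlen: int) -> int:
    return batch_size * dim * seqlen


small = num_output_elements(batch_size=2, dim=4, seqlen=4096)
large = num_output_elements(batch_size=2, dim=4096, seqlen=4096)
print(small, large, large // small)  # 32768 33554432 1024
```

`torch.allclose` fails if even one element is out of tolerance, so a 1024× larger output also gives 1024× more chances for a single rounding outlier. That possibility is worth ruling out alongside a genuine kernel bug. This is not a diagnosis.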
And here is the result of `pytest test_selective_scan.py -s`:
```
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
>   assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
E   AssertionError: assert False
E    +  where False = <built-in method allclose of type object at 0x7fe3c3414640>(tensor([[[ 3.9609e-03, -3.3309e+00, -5.4975e-01, ..., -2.1817e+00,\n 2.3821e-01, 3.5165e-01],\n [ 3....0, ..., 3.1667e-01,\n -4.9747e-01, -1.8112e-02]]], device='cuda:0',\n grad_fn=<SelectiveScanFnBackward>), tensor([[[ 3.9609e-03, -3.3309e+00, -5.4975e-01, ..., -2.1817e+00,\n 2.3821e-01, 3.5165e-01],\n [ 3....43e-02, -1.7847e+00, ..., 3.1667e-01,\n -4.9747e-01, -1.8111e-02]]], device='cuda:0', grad_fn=<MulBackward0>), rtol=0.0006, atol=0.002)
E    +    where <built-in method allclose of type object at 0x7fe3c3414640> = torch.allclose

test_selective_scan.py:114: AssertionError
====================================================== short test summary info ======================================================
FAILED test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-256-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-1024-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-2048-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-1-True-True-True-True-True-4096-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-512-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-2048-itype0-wtype0] - AssertionError: assert False
FAILED test_selective_scan.py::test_selective_scan[True-True-2-True-True-True-True-True-4096-itype0-wtype0] - AssertionError: assert False
```
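A single pass/fail does not show how many elements actually fall outside tolerance. One way to find out is to apply the criterion `torch.allclose` uses, |out - ref| <= atol + rtol * |ref|, element by element. The sketch below does this in pure Python. The helper name is hypothetical, and its defaults are the rtol/atol values printed in the failing assertion.

```python
import math


# Hypothetical helper (not part of mamba or PyTorch): list every element that
# fails the torch.allclose criterion |out - ref| <= atol + rtol * |ref|.
def allclose_violations(out, ref, rtol=6e-4, atol=2e-3):
    if len(out) != len(ref):
        raise ValueError("out and ref must have the same length")
    bad = []
    for i, (a, b) in enumerate(zip(out, ref)):
        if a == b:  # exact matches (including equal infinities) always pass
            continue
        diff = abs(a - b)
        allowed = atol + rtol * abs(b)
        # A non-finite diff means a NaN or mismatched infinity: never close
        # (this mirrors torch.allclose's default equal_nan=False).
        if not math.isfinite(diff) or diff > allowed:
            bad.append((i, diff, allowed))
    return bad


# The first two pairs are values visible in the log; the third pair is
# made up purely to show a failing element.
out = [3.9609e-03, -1.8112e-02, 1.0]
ref = [3.9609e-03, -1.8111e-02, 1.01]
for i, diff, allowed in allclose_violations(out, ref):
    print(f"index {i}: |diff|={diff:.3e} > allowed={allowed:.3e}")
```

With the real tensors, the same mask can be built directly in PyTorch as `(out - out_ref).abs() > atol + rtol * out_ref.abs()`. Calling `.sum()` on it counts the failing elements and `.nonzero()` locates them. That shows whether the failure comes from a few outliers or from a systematic error.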
It looks like there is a correctness issue in the kernel implementation. Could you please take a look?
I'm running causal-conv1d v1.1.0 and mamba-ssm v1.1.1.
Thanks!