
The Accuracy Problem of Mamba Operator

Open Unrealluver opened this issue 1 year ago • 4 comments

Greetings!

Thanks for your awesome work! The GPU-level optimization of the Mamba operator is impressive. However, I ran into an accuracy problem when running the unit tests in mamba/tests/ops/test_selective_scan.py. I got the output below:

mamba_inner_fn
Output max diff: 0.0
Output mean diff: 0.0
dxz max diff: 512.0
dA max diff: 131072.0
dD max diff: 0.0003662109375
ddelta_bias max diff: 8.0
dout_proj_weight max diff: 80.0
ddelta_proj_weight max diff: 1024.0
dx_proj_weight max diff: 320.0
dconv1d_weight max diff: 864.0
dconv1d_bias max diff: 704.0

It is worth noting that the gradients of variables such as dxz, dA, and ddelta_bias differ from those of the reference Mamba implementation. Could you share some possible reasons for this? And what kind of results should we use to judge the reliability of the Mamba operator? Looking forward to your reply~
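
(One common way to judge this is by relative error against a higher-precision reference rather than the raw max diff. Below is a minimal sketch of such a check; the helper name, tolerances, and the dA_fused/dA_ref tensors are illustrative assumptions, not the repo's actual test code.)

import torch

def report_diff(actual, reference, rtol=1e-3, atol=1e-2, name="tensor"):
    # Print absolute and relative error and apply a torch.allclose-style tolerance.
    abs_diff = (actual - reference).abs().max()
    rel_diff = abs_diff / reference.abs().max().clamp(min=1e-12)
    ok = torch.allclose(actual, reference, rtol=rtol, atol=atol)
    print(f"{name}: max abs diff {abs_diff.item():.3e}, "
          f"rel diff {rel_diff.item():.3e}, within tolerance: {ok}")
    return ok

# Hypothetical usage: dA_fused is the gradient from the fused kernel and
# dA_ref is the gradient from a pure-PyTorch reference run in float64.
# report_diff(dA_fused, dA_ref.to(dA_fused.dtype), name="dA")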

Unrealluver avatar Dec 21 '23 15:12 Unrealluver

I have a similar issue. This is my result:


---------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------
Output max diff: 0.25
Output mean diff: 0.005645751953125
==================================================================== short test summary info =====================================================================
FAILED tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update[2048-64-True-itype2] - AssertionError: assert False
FAILED tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update[4096-64-True-itype2] - AssertionError: assert False
================================================================= 2 failed, 72 passed in 20.46s ==================================================================

It seems there are some correctness issues.
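
For anyone trying to reproduce this, the two failing parameterizations can be re-run in isolation by passing their node ids (copied from the summary above) to pytest; a programmatic invocation would look like:

import pytest

pytest.main([
    "tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update[2048-64-True-itype2]",
    "tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update[4096-64-True-itype2]",
    "-v",
])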

jacklishufan avatar Dec 29 '23 00:12 jacklishufan

I have a similar question:

test_mamba_inner_fn(True, True, 128, torch.float32, torch.float32)

Then I get:

Output max diff: 0.0
Output mean diff: 0.0
dxz max diff: 320.0
dA max diff: 0.0
dD max diff: 0.0
ddelta_bias max diff: 0.0
dout_proj_weight max diff: 0.0
ddelta_proj_weight max diff: 0.0
dx_proj_weight max diff: 312.0
dconv1d_weight max diff: 512.0
dconv1d_bias max diff: 320.0

tyshiwo1 avatar Jan 20 '24 17:01 tyshiwo1

I'm getting similar results:

For example, for the test test_mamba_inner_fn with is_variable_B=True, is_variable_C=True, seqlen=128, itype=torch.float32, wtype=torch.complex64:

Output max diff: 0.0
Output mean diff: 0.0
dxz max diff: 5888.0
dx max diff: 5888.0
dz max diff: 1056.0
dA max diff: 1417760.5
dD max diff: 0.00042724609375
ddelta_bias max diff: 1024.0
dout_proj_weight max diff: 0.0
ddelta_proj_weight max diff: 98304.0
dx_proj_weight max diff: 11264.0
dconv1d_weight max diff: 20480.0
dconv1d_bias max diff: 28672.0

Wondering whether anyone has an explanation for this?

TangTangFei avatar Apr 10 '24 06:04 TangTangFei

I have changed the scale of the following tensors:

xz = (0.01 * torch.rand(bs, 2 * dim, seq_len, dtype=itype, device=device)).requires_grad_()  # scale the input down by 100x
g = torch.randn_like(out) * 0.01  # scale the upstream gradient down as well

The discrepancies then become rather small, so the large differences in the gradients were possibly due to numerical instability when the numbers become too large.
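
A quick way to sanity-check that interpretation is to look at relative rather than absolute error: with unscaled inputs the activations and gradients are huge, so an absolute max diff in the hundreds can still be a tiny relative error. A minimal sketch, assuming grad_fused and grad_ref are the corresponding gradients from the fused kernel and the reference implementation (hypothetical names):

import torch

def relative_error(grad_fused, grad_ref, eps=1e-12):
    # Largest elementwise difference, normalized by the magnitude of the reference.
    denom = grad_ref.abs().max().clamp(min=eps)
    return ((grad_fused - grad_ref).abs().max() / denom).item()

# With activations on the order of 1e5, an absolute max diff of a few hundred
# is only a ~1e-3 relative error, which is plausible float32 noise for long
# reduction chains.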

TangTangFei avatar Apr 12 '24 10:04 TangTangFei