
add `mamba_chunk_scan_combined` and `mamba_split_conv1d_scan_combined` tests

Open • garrett361 opened this issue 1 year ago • 5 comments

This PR adds correctness tests for `mamba_chunk_scan_combined` and `mamba_split_conv1d_scan_combined`, which seemed to be missing. Both the forward and backward passes are tested against their reference implementations. Correctness when `seq_idx` is provided is also tested.
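
The testing pattern described above (a chunked implementation checked elementwise against a simple reference, including when `seq_idx` marks packed-sequence boundaries) can be sketched in plain Python on a toy scalar recurrence h_t = a_t * h_{t-1} + x_t. This is a minimal sketch, not the Mamba-2 math or the `mamba_ssm` API: `reference_scan`, `chunked_scan` and `assert_close` are hypothetical helpers, and the rule "reset the state whenever `seq_idx` changes" is an assumption made for illustration. Only the forward pass is shown; gradients would be compared the same way.

```python
import random


def reference_scan(a, x, seq_idx):
    """Sequential reference for the toy recurrence h_t = a_t * h_{t-1} + x_t.

    Assumed convention: the state resets to zero whenever seq_idx changes,
    i.e. when a new packed sequence starts.
    """
    h, out = 0.0, []
    for t in range(len(x)):
        if t > 0 and seq_idx[t] != seq_idx[t - 1]:
            h = 0.0
        h = a[t] * h + x[t]
        out.append(h)
    return out


def chunked_scan(a, x, seq_idx, chunk_size):
    """Chunked version: scan each chunk from a zero state, then add the state
    carried in from earlier chunks, scaled by the decay accumulated so far."""
    out, carry = [], 0.0
    for start in range(0, len(x), chunk_size):
        h_local, decay = 0.0, 1.0
        for t in range(start, min(start + chunk_size, len(x))):
            if t > 0 and seq_idx[t] != seq_idx[t - 1]:
                # Sequence boundary: drop the local state and the carry-in.
                h_local, decay = 0.0, 0.0
            h_local = a[t] * h_local + x[t]
            decay *= a[t]
            out.append(h_local + decay * carry)
        carry = out[-1]  # full state at the end of this chunk
    return out


def assert_close(got, ref, rtol, atol):
    """Elementwise |got - ref| <= atol + rtol * |ref|, as torch.testing.assert_close does."""
    assert len(got) == len(ref)
    for i, (g, r) in enumerate(zip(got, ref)):
        assert abs(g - r) <= atol + rtol * abs(r), (i, g, r)


random.seed(0)
n = 37
a = [random.uniform(0.5, 1.0) for _ in range(n)]
x = [random.gauss(0.0, 1.0) for _ in range(n)]
seq_idx = [0] * 10 + [1] * 15 + [2] * 12  # three packed sequences

ref = reference_scan(a, x, seq_idx)
for chunk_size in (1, 4, 5, 8, 64):
    assert_close(chunked_scan(a, x, seq_idx, chunk_size), ref, rtol=1e-9, atol=1e-12)
print("chunked scan matches the reference for every chunk size")
```

Sweeping several chunk sizes, including one (5) that puts sequence boundaries exactly on chunk starts, exercises both the carry-in path and the reset path.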

garrett361 • Jan 14 '25

@tridao I know the kernels inside `mamba_chunk_scan_combined` and `mamba_split_conv1d_scan_combined` are individually tested, but I thought it would be worthwhile to add these more end-to-end tests. Thoughts?

garrett361 • Jan 21 '25

Any idea why the tolerances need to be that high? Those tolerances seem very high for float32. It is probably related to #683 and #571.

peterbjorgensen • Apr 09 '25

Yes, concerningly high, at least for the backward pass, where some tests need tol = 1e-1 and/or are sensitive to the random seed.

My first suspicion was that the problem is in the tests rather than the kernels, but I haven't found any issues yet. And since the forward tests pass at reasonable-ish 1e-2/1e-3 tolerances, any error would need to be fairly subtle.
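
One way to calibrate what a "reasonable" float32 tolerance is: measure how far plain float32 rounding moves a long reduction away from an exact reference. The stdlib-only sketch below emulates float32 by rounding through `struct` after every addition. It is an illustration under that assumption, not a model of the actual kernels (their chunked matmuls, exp/cumsum steps and any TF32 use are not represented). For a running sum of n nonnegative terms, the standard worst-case relative error bound is about (n - 1) * 2^-24, roughly 6e-4 for n = 10,000, far below 1e-1.

```python
import math
import random
import struct


def f32(v):
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


def sum_f32(values):
    """Emulate a float32 running sum by rounding after every addition."""
    acc = 0.0
    for v in values:
        acc = f32(acc + v)
    return acc


def rel_err(got, ref):
    return abs(got - ref) / abs(ref)


random.seed(0)
vals = [f32(random.uniform(0.0, 1.0)) for _ in range(10_000)]
exact = math.fsum(vals)  # correctly rounded float64 sum, used as the reference
err_fwd = rel_err(sum_f32(vals), exact)
err_rev = rel_err(sum_f32(reversed(vals)), exact)
print(f"float32 relative error: forward={err_fwd:.1e}, reversed={err_rev:.1e}")
```

Both orderings land well inside the worst-case bound, which is consistent with the view that tolerances near 1e-1 are not explained by float32 rounding of a reduction alone.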

I have also found some non-determinism in the backward pass for the `D` gradients. I haven't posted about it yet; I will try to today.

garrett361 • Apr 09 '25

Also relevant: non-determinism in the backward pass is apparently expected, due to atomic adds.
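
Why atomic adds make results run-to-run dependent: floating-point addition is not associative, so when several blocks add partial contributions into the same gradient element, the final value depends on the order in which the adds land. The sketch below enumerates every order for three hypothetical contributions; the magnitudes are exaggerated so the effect is obvious, whereas in a real kernel the differences are typically in the last few bits.

```python
import operator
from functools import reduce
from itertools import permutations

# Hypothetical per-block contributions to one gradient element (e.g. one
# entry of the D gradient). With atomic adds, the summation order depends
# on which block finishes first.
contributions = [1e16, -1e16, 1.0]


def accumulate(order):
    """Left-to-right float64 sum, mimicking one possible atomic-add order."""
    return reduce(operator.add, order, 0.0)


results = {accumulate(p) for p in permutations(contributions)}
print(sorted(results))  # [0.0, 1.0]: same inputs, order-dependent result
```

Because of this, a determinism check for such a backward pass can compare two runs with a small tolerance rather than requiring bitwise equality.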

garrett361 • Apr 09 '25

> Any idea why the tolerances need to be that high? Those tolerances seem very high for float32. It is probably related to #683 #571

Hi, thanks for mentioning this. I posted a solution for my case in #571; you might want to check it out. With it, I was able to get tolerances as tight as 1e-8 for all gradients and outputs.

karannb • Apr 15 '25