
jax.scipy.sparse.linalg.cg inconsistent results between runs

Open Dinple opened this issue 2 years ago • 6 comments

Hi all,

The conjugate gradient function inside jax.scipy.sparse seems to be very inconsistent on JAX GPU. I'm new to JAX, so I'm not sure if this issue has been addressed somewhere. I believe it is somewhat related to #565 and #9784.

To see the full picture, I have saved both the input A and b, so every run uses identical inputs. No preconditioning is applied.

I tested my results on three platforms: CPU [Colab + local], GPU [Colab + local], and TPU [Colab].

Across all the runs I have done, the three platforms produce different results from one another, but only GPU is inconsistent between runs.

  • On my local machine, jax on CPU produces exactly the same result as Colab CPU, and it is CONSISTENT between runs.
  • On Colab, jax on TPU is also CONSISTENT between runs.
  • On GPU, both Colab and my local machine show large INCONSISTENCY between runs, sometimes even outputting a nan matrix.

I have seen people mention issues with the CUDA version, so I tested CUDA 11.1, 11.2, and 11.4; they all show the same problem.
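Roughly, the setup looks like this (a sketch, not my exact script; the .npy file names are placeholders for my saved inputs):

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import cg

# Placeholder file names: A and b were saved once, so every run
# sees bit-identical inputs.
A = sparse.BCOO.fromdense(jnp.asarray(np.load("A.npy")))
b = jnp.asarray(np.load("b.npy"))

# Each run prints the same values on CPU/TPU but differs on GPU.
x, _ = cg(lambda v: A @ v, b)
print(x[:10])
```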

To show how much the results change, here is the output of three different runs:

DeviceArray([ 9.28246680e+03, 1.50545068e+04, 1.90608145e+04, 2.23634746e+04, 2.50702012e+04, 2.76033926e+04, 2.99257559e+04, 3.21613457e+04, 3.42872852e+04,...

DeviceArray([-8.13425984e-03, -1.17020588e-02, -1.27483038e-02, -1.18785836e-02, -9.67487786e-03, -6.41405629e-03, -2.11878261e-03, 3.24898120e-03, 9.95288976e-03,...

DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,...

I am using:

  • jax 0.3.4
  • jaxlib 0.3.2+cuda11.cudnn82
  • scipy 1.8.0
  • numpy 1.22.3

Dinple avatar Apr 13 '22 00:04 Dinple

FYI: here is a minimal example: https://colab.research.google.com/drive/1Z802HuUZ_TCTRxeQNRvC6XJppEiGAiyQ?usp=sharing

For nearly singular matrices (e.g. conditioned graph Laplacians), cg returns different solutions (sometimes valid, sometimes nan) on GPU between runs when x0 is not set. Is this expected? GMRES seems a bit more stable.
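For reference, a sketch of the comparison I mean (A_op and b stand in for the operator and right-hand side from the repro above); pinning x0 at least removes one source of run-to-run variation:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg, gmres

def compare_solvers(A_op, b):
    # Start both solvers from the same explicit initial guess.
    x0 = jnp.zeros_like(b)
    x_cg, _ = cg(A_op, b, x0=x0)
    x_gmres, _ = gmres(A_op, b, x0=x0)
    return x_cg, x_gmres
```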

choltz95 avatar Apr 13 '22 01:04 choltz95

cond = np.linalg.cond(A.todense())
print(cond)
# 1.1747166e+17

It seems the condition number is very large. Is this what you expected?

tlu7 avatar Apr 29 '22 16:04 tlu7

Yes, so that's one thing. It's mainly an issue for poorly conditioned matrices; I guess typical, correct use cases won't suffer from this problem.

But the main thing is the inconsistency we see between cg on GPU and cg on CPU. Is there any reason that running cg on different devices leads to different results despite identical input?

I guess my question is less about finding a solution (although that would be great) and more about verifying that this is expected behavior (and not a sign of a deeper problem)?

choltz95 avatar Apr 29 '22 18:04 choltz95

Looks like the input is symmetric positive definite: I have checked np.all(np.linalg.eigvals(A.todense()) > 0) and np.allclose(A_dense, A_dense.T.conj(), rtol=rtol, atol=atol).

Have you tried using float64?
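(JAX disables float64 by default, so it has to be switched on explicitly; a minimal sketch:)

```python
import jax
# JAX arrays default to float32; enable 64-bit before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.zeros(3).dtype)  # float64
```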

tlu7 avatar Apr 29 '22 18:04 tlu7

The nondeterministic behavior may come from multithreading.
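One way to test that hypothesis would be to isolate the sparse matvec and compare repeated applications bitwise; a sketch, assuming A is the BCOO matrix from the repro:

```python
import jax.numpy as jnp

def matvec_is_deterministic(A, n_trials=10):
    # Apply the same matvec repeatedly; any bitwise difference between
    # applications points at nondeterminism in the matvec itself.
    v = jnp.ones(A.shape[1], dtype=A.dtype)
    ref = A @ v
    return all(bool(jnp.all((A @ v) == ref)) for _ in range(n_trials))
```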

tlu7 avatar Apr 29 '22 23:04 tlu7

Was this issue resolved? @Dinple

sudhakarsingh27 avatar Aug 08 '22 20:08 sudhakarsingh27

I reran the colab repro.

On Colab GPU, I get values that are fairly consistently large and similar (with occasional nan and inf):

[0.0000000e+00 9.8054398e+10 1.9872968e+11 6.2595138e+10 2.0462353e+11
 2.4414941e+10 1.1615307e+11 7.3973126e+12 2.6439701e+13 4.8272494e+11]

[0.0000000e+00 2.8103875e+13 4.0418297e+13 4.0165729e+13 4.6595383e+13
 4.9564275e+13 4.1783241e+13 9.9750607e+13 3.7144668e+13 4.0531598e+13]
 
[0.0000000e+00 2.3296439e+11 9.7694744e+13 4.8008027e+11 1.4699356e+11
 7.5435302e+12 9.5971885e+10 9.2001908e+11 2.6396608e+11 9.4231685e+11]

[0.0000000e+00 3.2017095e+11 2.1751036e+13 4.9183644e+11 2.9387155e+11
 2.1419085e+12 3.1061000e+11 1.8692791e+11 1.2441883e+13 4.7176105e+12]

[0.0000000e+00 5.7433942e+11 3.3398978e+12 3.7421702e+11 2.7343428e+12
 1.1824465e+12 5.5289119e+11 2.3275011e+13 8.5892014e+10 3.8059177e+12]

[0.0000000e+00 4.9893282e+10 1.5721413e+11 7.4679969e+13 3.8181389e+12
 1.5548703e+10 1.5535525e+12 2.3955174e+11 1.4923589e+13           nan]

[0.0000000e+00 6.1637303e+11 5.2090693e+11 6.1480265e+11 6.0722269e+11
           inf 2.5805787e+13 4.8744081e+11 6.9813535e+11 3.0362653e+13]

[0.0000000e+00 2.8956185e+13 8.2401290e+11 7.0151563e+11 9.3606137e+11
 4.9036857e+11 2.1517848e+12 1.5174398e+12 6.1058523e+12 7.8914683e+11]

[0.0000000e+00 2.8998506e+12 3.1845496e+11 3.1819612e+11           inf
 1.7479309e+13 1.3300310e+11 2.3351959e+13 7.2910045e+10 1.4908412e+13]

[0.0000000e+00 6.4928147e+11 8.0207806e+11 1.4854672e+12 7.3838559e+11
 1.3758692e+12 8.9137250e+11 1.3934976e+12 9.1503893e+11 2.3913224e+12]

On Colab TPU, I get consistently all zeros:

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

sudhakarsingh27 avatar Oct 14 '22 21:10 sudhakarsingh27

After some debugging, I found that cuSPARSE isn't being called. The from_scipy_sparse function sets indices_sorted=False, and _bcoo_dot_general_gpu_lowering then falls back to _bcoo_dot_general_default_lowering based on that indices_sorted value, so the issue is in _bcoo_dot_general_impl. Sorting the indices (by adding A = A.sort_indices()) makes the code call cuSPARSE, and then the results are consistent.
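A minimal sketch of the workaround (the random matrix is just a stand-in for the real input):

```python
import numpy as np
import scipy.sparse
from jax.experimental import sparse

# Stand-in for the real system matrix.
A_scipy = scipy.sparse.random(100, 100, density=0.05, format="coo", dtype=np.float32)

A = sparse.BCOO.from_scipy_sparse(A_scipy)
print(A.indices_sorted)  # False -> default (GPU-inconsistent) lowering is used

A = A.sort_indices()     # with sorted indices, the cuSPARSE path is taken
print(A.indices_sorted)  # True
```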

almogsegal avatar Oct 19 '22 11:10 almogsegal

Thanks for looking into this! We recently enabled the lowering of BCOO dot_general to cuSPARSE (https://github.com/google/jax/pull/12138). Yes, indices_sorted=True is one of the requirements for using cuSPARSE.

tlu7 avatar Oct 19 '22 17:10 tlu7

@tlu7 Just another remark: jax currently uses CUSPARSE_MV_ALG_DEFAULT as the algorithm parameter for cusparseSpMV in https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse_kernels.cc and https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse.cc, which is deprecated and may produce non-deterministic results in general. I would suggest using CUSPARSE_SPMV_COO_ALG2 instead; according to the docs (https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmv):

> Provides deterministic (bit-wise) results for each run. If opA != CUSPARSE_OPERATION_NON_TRANSPOSE, it is identical to CUSPARSE_SPMV_COO_ALG1

marsaev avatar Oct 20 '22 14:10 marsaev

Thanks for the suggestions! Can you also share insights on the cuSPARSE matmat (SpMM) algorithms? Which one should we use as the default for jax? @marsaev

https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmm

tlu7 avatar Oct 20 '22 17:10 tlu7

@jakevdp do you have a suggestion for what may be wrong in _bcoo_dot_general_impl? Is it the same algorithm that is used for TPUs/CPUs?

almogsegal avatar Oct 26 '22 08:10 almogsegal

@tlu7

  1. If determinism (reproducibility) is needed (which I assume is true for JAX), then there is only one option: for SpMV use CUSPARSE_SPMV_COO_ALG2 and CUSPARSE_SPMV_CSR_ALG2, and for SpMM use CUSPARSE_SPMM_COO_ALG2 and CUSPARSE_SPMM_CSR_ALG3, according to the matrix format.
  2. If there is no such requirement, then unfortunately no heuristics are available, only the guidance in the docs for those functions, i.e. https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmm . There are also studies like https://arxiv.org/pdf/2202.08556.pdf that report other users' experience with the different algorithm values.

marsaev avatar Oct 26 '22 08:10 marsaev

@marsaev @fbusato

Thank you. A follow-up question about the CUDA versions for these new algorithms.

I found this in the release notes:

> [2.5.12. cuSPARSE: Release 11.2 Update 1](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cusparse-11.2.1)
> ...
> New algorithms for CSR/COO Sparse Matrix - Vector Multiplication (cusparseSpMV) with better performance.
> ...
> New algorithm (CUSPARSE_SPMM_CSR_ALG3) for Sparse Matrix - Matrix Multiplication
> ...

Shall I assume that all these algorithms are available from CUDA 11.2 onwards? Is there any document where I can find this information? I need the version information to make sure those JAX routines are backward compatible.

tlu7 avatar Oct 26 '22 17:10 tlu7

@marsaev @fbusato

Compared to the default algorithms for SpMV and SpMM, do the four algorithms that provide determinism have trade-offs in accuracy? I broke a few accuracy tests when changing from the default algorithms to the four aforementioned algorithms.

tlu7 avatar Oct 27 '22 04:10 tlu7

Hi @tlu7,

> Shall I assume that all these algorithms are available from CUDA 11.2 onwards? Is there any document where I can find this information? I need the version information to make sure those JAX routines are backward compatible.

Yes, these algorithm enumerators are compatible with CUDA 11.x and 12.x. They could change in CUDA 13.x.

> Compared to the default algorithms for SpMV and SpMM, do the four algorithms that provide determinism have trade-offs in accuracy? I broke a few accuracy tests when changing from the default algorithms to the four aforementioned algorithms.

No, we cannot say that one algorithm is more accurate than another.

fbusato avatar Oct 27 '22 18:10 fbusato

Thanks @fbusato !

Can you share more information about the versions within 11.x in which the four algorithms became available?

It seems CUSPARSE_SPMM_CSR_ALG3 has been available since 11.2, but the other three are unclear from the release notes.

> 2.5.12. cuSPARSE: Release 11.2 Update 1 ... New algorithms for CSR/COO Sparse Matrix - Vector Multiplication (cusparseSpMV) with better performance. ... New algorithm (CUSPARSE_SPMM_CSR_ALG3) for Sparse Matrix - Matrix Multiplication ...

tlu7 avatar Oct 27 '22 18:10 tlu7

There is a small trick that you can use to check old toolkit documentation 😀: https://developer.nvidia.com/cuda-toolkit-archive. CUSPARSE_SPMM_CSR_ALG3 and the SpMV algorithms were introduced in CUDA 11.2u1: https://docs.nvidia.com/cuda/archive/11.2.1/cusparse/index.html#cusparse-generic-function-spmm

fbusato avatar Oct 27 '22 18:10 fbusato

Thanks! It works like a charm. Needs some patience, though :)

tlu7 avatar Oct 27 '22 18:10 tlu7