jax.scipy.sparse.linalg.cg inconsistent results between runs
Hi all,
the conjugate gradient function inside jax.scipy.sparse seems to be very inconsistent on JAX GPU. I'm a new user of JAX, so I'm not sure if this issue has been addressed somewhere. I believe it is somewhat related to #565 and #9784.
To keep the comparison fair, I have saved both the input A and b, so every run uses exactly the same inputs. No preconditioning is applied.
I tested on three platforms: CPU (Colab and local), GPU (Colab and local) and TPU (Colab).
Across all the runs I have done, the three platforms produce results that differ from one another, but only the GPU is inconsistent between runs.
- On my local machine, JAX on CPU produces exactly the same result as Colab CPU, and it is CONSISTENT between different runs.
- On Colab, JAX on TPU is also CONSISTENT between different runs.
- On GPU, both Colab and my local machine show large INCONSISTENCY between runs, sometimes even returning a matrix of nan.
I have seen people mention issues with the CUDA version, so I tested CUDA 11.1, 11.2 and 11.4; they all show the same problem.
To show how much the result changes, here is the output of three different runs:
DeviceArray([ 9.28246680e+03, 1.50545068e+04, 1.90608145e+04, 2.23634746e+04, 2.50702012e+04, 2.76033926e+04, 2.99257559e+04, 3.21613457e+04, 3.42872852e+04,...
DeviceArray([-8.13425984e-03, -1.17020588e-02, -1.27483038e-02, -1.18785836e-02, -9.67487786e-03, -6.41405629e-03, -2.11878261e-03, 3.24898120e-03, 9.95288976e-03,...
DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,...
I am using
jax 0.3.4
jaxlib 0.3.2+cuda11.cudnn82
scipy 1.8.0
numpy 1.22.3
FYI: here is a minimal example: https://colab.research.google.com/drive/1Z802HuUZ_TCTRxeQNRvC6XJppEiGAiyQ?zusp=sharing
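For readers without Colab access, the gist of the repro is roughly the sketch below; the filenames, tolerance and maxiter here are assumptions, not the exact notebook contents:

import numpy as np
import scipy.sparse
import jax.numpy as jnp
from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg

# Load the saved inputs (placeholder filenames).
A_scipy = scipy.sparse.load_npz("A.npz")
b = jnp.asarray(np.load("b.npy"))

# Wrap A as a JAX BCOO sparse matrix and solve A x = b with conjugate gradients.
A = jsparse.BCOO.from_scipy_sparse(A_scipy)
x, _ = cg(lambda v: A @ v, b, tol=1e-6, maxiter=1000)
print(x[:10])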
For nearly singular matrices (e.g. ill-conditioned graph Laplacians), cg returns different (sometimes valid, sometimes nan) solutions on GPU between runs when x0 is not set. Is this expected? GMRES seems a bit more stable.
cond = np.linalg.cond(A.todense())
print(cond)
1.1747166e+17
It seems the condition number is very large. Is this what you expected?
Yes, so that's one thing. It's mainly an issue for poorly conditioned matrices. I guess typical and correct use cases won't suffer from this problem.
But the main thing is the inconsistency we see between cg on GPU and cg on CPU. Is there any reason that running cg on different devices leads to different results despite identical input?
I guess my question is less about finding a solution (although that would be great) and more about verifying that this is expected behavior (and not a sign of a deeper problem).
Looks like the input is symmetric positive definite. I have checked
np.all(np.linalg.eigvals(A.todense()) > 0)
and
np.allclose(A_dense, A_dense.T.conj(), rtol=rtol, atol=atol)
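For completeness, a self-contained version of those checks might look like this (the tolerances and the use of eigvalsh for the symmetric case are my choices, not from the notebook):

import numpy as np

A_dense = np.asarray(A.todense())    # A is the scipy.sparse matrix from the repro

# Hermitian / symmetry check
print(np.allclose(A_dense, A_dense.T.conj(), rtol=1e-5, atol=1e-8))

# Positive-definiteness check via the eigenvalue spectrum
print(np.all(np.linalg.eigvalsh(A_dense) > 0))

# Condition number
print(np.linalg.cond(A_dense))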
Have you tried using float64?
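For reference, a minimal sketch of switching the repro to float64; this has to run before any JAX arrays are created:

import jax

# Enable 64-bit floats globally; by default JAX keeps everything in float32,
# which loses a lot of precision on badly conditioned systems.
jax.config.update("jax_enable_x64", True)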
The nondeterministic behavior may come from multithreading.
Was this issue resolved? @Dinple
I reran the Colab repro.
On Colab GPU, I get values which are fairly consistently large and similar, with some occasional nan and inf:
[0.0000000e+00 9.8054398e+10 1.9872968e+11 6.2595138e+10 2.0462353e+11
2.4414941e+10 1.1615307e+11 7.3973126e+12 2.6439701e+13 4.8272494e+11]
[0.0000000e+00 2.8103875e+13 4.0418297e+13 4.0165729e+13 4.6595383e+13
4.9564275e+13 4.1783241e+13 9.9750607e+13 3.7144668e+13 4.0531598e+13]
[0.0000000e+00 2.3296439e+11 9.7694744e+13 4.8008027e+11 1.4699356e+11
7.5435302e+12 9.5971885e+10 9.2001908e+11 2.6396608e+11 9.4231685e+11]
[0.0000000e+00 3.2017095e+11 2.1751036e+13 4.9183644e+11 2.9387155e+11
2.1419085e+12 3.1061000e+11 1.8692791e+11 1.2441883e+13 4.7176105e+12]
[0.0000000e+00 5.7433942e+11 3.3398978e+12 3.7421702e+11 2.7343428e+12
1.1824465e+12 5.5289119e+11 2.3275011e+13 8.5892014e+10 3.8059177e+12]
[0.0000000e+00 4.9893282e+10 1.5721413e+11 7.4679969e+13 3.8181389e+12
1.5548703e+10 1.5535525e+12 2.3955174e+11 1.4923589e+13 nan]
[0.0000000e+00 6.1637303e+11 5.2090693e+11 6.1480265e+11 6.0722269e+11
inf 2.5805787e+13 4.8744081e+11 6.9813535e+11 3.0362653e+13]
[0.0000000e+00 2.8956185e+13 8.2401290e+11 7.0151563e+11 9.3606137e+11
4.9036857e+11 2.1517848e+12 1.5174398e+12 6.1058523e+12 7.8914683e+11]
[0.0000000e+00 2.8998506e+12 3.1845496e+11 3.1819612e+11 inf
1.7479309e+13 1.3300310e+11 2.3351959e+13 7.2910045e+10 1.4908412e+13]
[0.0000000e+00 6.4928147e+11 8.0207806e+11 1.4854672e+12 7.3838559e+11
1.3758692e+12 8.9137250e+11 1.3934976e+12 9.1503893e+11 2.3913224e+12]
On Colab TPU, I get consistently all zeros:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
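For reference, the numbers above come from a loop roughly like the sketch below (not the exact notebook code; A and b are the saved inputs and cg is jax.scipy.sparse.linalg.cg):

import numpy as np

# Solve the same system ten times and print the first ten entries of each solution.
for _ in range(10):
    x, _ = cg(lambda v: A @ v, b, maxiter=1000)
    print(np.asarray(x[:10]))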
After some debugging, I found out that cuSPARSE isn't being called. The from_scipy_sparse function sets indices_sorted=False, and then in _bcoo_dot_general_gpu_lowering the _bcoo_dot_general_default_lowering function is being called based on the indices_sorted value. So the issue is in _bcoo_dot_general_impl.
Sorting the indices (by adding A = A.sort_indices()) makes the code call cuSPARSE, and then the results are consistent.
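Concretely, the workaround looks something like this (a sketch, reusing the scipy matrix A_scipy and right-hand side b from the repro):

from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg

A = jsparse.BCOO.from_scipy_sparse(A_scipy)
print(A.indices_sorted)    # False: the default (non-cuSPARSE) lowering is used
A = A.sort_indices()
print(A.indices_sorted)    # True: the cuSPARSE lowering can now be taken
x, _ = cg(lambda v: A @ v, b)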
Thanks for looking into this! We recently enabled the lowering of BCOO dot_general to cuSPARSE (https://github.com/google/jax/pull/12138). Yes, indices_sorted=True is one of the requirements for using cuSPARSE.
@tlu7
Just another remark: currently JAX uses CUSPARSE_MV_ALG_DEFAULT as the algorithm parameter to cusparseSpMV in https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse_kernels.cc and https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse.cc, which is deprecated and may produce non-deterministic results in general. I would suggest using CUSPARSE_SPMV_COO_ALG2 instead; according to the docs (https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmv):
Provides deterministic (bit-wise) results for each run. If opA != CUSPARSE_OPERATION_NON_TRANSPOSE, it is identical to CUSPARSE_SPMV_COO_ALG1.
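Until a change like that lands in jaxlib, one quick way to probe from Python whether the SpMV path on a given install is bit-wise reproducible (a sketch; A is the sorted BCOO matrix and b the right-hand side from the repro):

import numpy as np

# Run the same sparse matvec repeatedly and compare bit patterns.
y0 = np.asarray(A @ b)
for _ in range(5):
    print(np.array_equal(y0, np.asarray(A @ b)))    # all True only on a deterministic path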
Thanks for the suggestions! Can you also share insights on the cuSPARSE matmat (SpMM) algorithms? Which one shall we use as the default for JAX? @marsaev
https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmm
@jakevdp do you have a suggestion as to what may be wrong in _bcoo_dot_general_impl? Is it the same algorithm that is used on TPUs/CPUs?
@tlu7
- If determinism (reproducibility) is needed (which I assume is true for JAX), then there is only one option: for SpMV use CUSPARSE_SPMV_COO_ALG2 and CUSPARSE_SPMV_CSR_ALG2, and for SpMM use CUSPARSE_SPMM_COO_ALG2 and CUSPARSE_SPMM_CSR_ALG3, according to the matrix format.
- If there is no such requirement, unfortunately there are no heuristics available, only the guidance from the docs for those functions, i.e. https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmm. There are also studies like https://arxiv.org/pdf/2202.08556.pdf that show other users' experience with different algorithm values.
@marsaev @fbusato
Thank you. A follow-up question about the CUDA versions required for these new algorithms.
I found this in the release notes:
[2.5.12. cuSPARSE: Release 11.2 Update 1](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cusparse-11.2.1)
...
New algorithms for CSR/COO Sparse Matrix - Vector Multiplication (cusparseSpMV) with better performance.
...
New algorithm (CUSPARSE_SPMM_CSR_ALG3) for Sparse Matrix - Matrix Multiplication
...
Shall I assume that all these algorithms are available in CUDA 11.2 and onwards? Is there a document where I can find this information? I need the version information to make sure those JAX routines are backward compatible.
@marsaev @fbusato
Compared to the default algorithms for SpMV and SpMM, do the four algorithms that provide determinism have trade-offs in accuracy? I broke a few accuracy tests due to the change from the default algorithms to the four aforementioned algorithms.
Hi @tlu7,
Shall I assume that all these algorithms are available in CUDA 11.2 and onwards? Is there a document where I can find this information? I need the version information to make sure those JAX routines are backward compatible.
Yes, these algorithm enumerators are compatible with CUDA 11.x and 12.x. It could change in CUDA 13.x.
Compared to the default algorithms for SpMV and SpMM, do the four algorithms that provide determinism have trade-offs in accuracy? I broke a few accuracy tests due to the change from the default algorithms to the four aforementioned algorithms.
No, we cannot say that one algorithm is more accurate than another.
Thanks @fbusato !
Can you share more information on the versions within 11.x where the four algorithms became available?
It seems CUSPARSE_SPMM_CSR_ALG3 has been available since 11.2, while the other three are unclear from the release notes.
2.5.12. cuSPARSE: Release 11.2 Update 1 ... New algorithms for CSR/COO Sparse Matrix - Vector Multiplication (cusparseSpMV) with better performance. ... New algorithm (CUSPARSE_SPMM_CSR_ALG3) for Sparse Matrix - Matrix Multiplication ...
There is a small trick that you can use to check old toolkit documentation 😀 https://developer.nvidia.com/cuda-toolkit-archive
CUSPARSE_SPMM_CSR_ALG3 and the SpMV algorithms were introduced in CUDA 11.2u1: https://docs.nvidia.com/cuda/archive/11.2.1/cusparse/index.html#cusparse-generic-function-spmm
Thanks! It works like a charm. Needs some patience though :)