Unit test "tests/x64_context_test.py::X64ContextTests::test_near_singular_inverse" fails when the input matrix is batched
Please apply the following patch to run the "test_near_singular_inverse" with batched inverse:
diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py
index d7a5a7bd..7c1bee24 100644
--- a/tests/x64_context_test.py
+++ b/tests/x64_context_test.py
@@ -84,8 +84,8 @@ class X64ContextTests(jtu.JaxTestCase):
self.skipTest("64-bit inverse not available on TPU")
@partial(_maybe_jit, jit, static_argnums=1)
def near_singular_inverse(key, N, eps):
- X = random.uniform(key, (N, N))
- X = X.at[-1].mul(eps)
+ X = random.uniform(key, (2, N, N))
+ X = X.at[:, -1].mul(eps)
return jnp.linalg.inv(X)
key = random.PRNGKey(1701)
and then run the unit test as follows; the test will fail:
pytest tests/x64_context_test.py::X64ContextTests::test_near_singular_inverse_jit=None
I believe this issue comes from the fact that JAX uses the XLA implementation for the unbatched version and an internal implementation for the batched version (please look here). I also think the main cause of this issue is that, for some reason, FTZ (flush to zero) is enabled in the XLA path. However, I found that FTZ is not enabled by default for the XLA path either (please look here). I have tried to dump the XLA outputs via "XLA_FLAGS=--xla_dump_to=
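(For reference, a minimal sketch of one way to compare what the batched and unbatched cases lower to, assuming a JAX version that supports the jax.jit(...).lower(...) AOT API; the shapes below are only illustrative:)
import jax
import jax.numpy as jnp
x_unbatched = jnp.eye(3)
x_batched = jnp.eye(3)[None]
# Print the lowered computation for each shape; on GPU, any difference between the
# batched and unbatched inverse lowerings (e.g. which solver calls appear) should show up here.
print(jax.jit(jnp.linalg.inv).lower(x_unbatched).as_text())
print(jax.jit(jnp.linalg.inv).lower(x_batched).as_text())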
Hi - thanks for the report. The code you linked to is in the GPU translation rule. Just to confirm: are you running this on a GPU?
Now that I look at it, there is a similar pattern for CPU.
Can you say more about what issue this difference in implementation between batched and unbatched results is causing for you?
Here's a more to-the-point demonstration of the difference in behavior between batched & unbatched inverse (run on colab CPU):
import jax.numpy as jnp
x = jnp.array([[[1, 2, 3],
[4, 5, 6],
[0, 0, 0]]])
# Batched: lowers to XLA
print(jnp.linalg.inv(x)[0])
# [[-1.6666666 0.6666666 1. ]
# [ 1.3333334 -0.33333334 -2. ]
# [ 0. 0. 1. ]]
# Unbatched: lowers to lapack
print(jnp.linalg.inv(x[0]))
# [[ nan nan nan]
# [ nan nan -inf]
# [ nan nan inf]]
Hi Jake,
Hi - thanks for the report. The code you linked to is in the GPU translation rule. Just to confirm: are you running this on a GPU?
Yes, I am running this on GPU. I have access to both AMD and Nvidia GPUs.
Now that I look at it, there is a similar pattern for CPU.
Can you say more about what issue this difference in implementation between batched and unbatched results is causing for you?
I work on AMD's ML framework team. My recent project is supporting JAX on the ROCm stack. I first saw this issue when I tried to run the X64ContextTests unit tests on ROCm and AMD GPUs. After running these (unbatched) tests on both Nvidia and AMD GPUs, I found that Nvidia passed the test but AMD did not. I tried to track down the divergence between the two and ended up identifying this difference between the batched and unbatched versions.
@jakevdp Isn't the matrix you've specified singular? It doesn't seem like it has a well-defined inverse, so any output is equally wrong. But the original issue seems legit.
Also, it looks like for the batched case we're just calling to cusolver/rocsolver, so ultimately this might just point to numerical instability inside those libraries. I think that the only thing we could do is try always using the XLA implementation (which does seem to support batching). @hawkinsp given the numerical instability, should we keep using cusolver?
Yes, it's singular. My interpretation of the issue is that it boils down to batched vs non-batched inverses using different algorithms that handle ill-posed inputs differently; my example was meant to demonstrate that more succinctly. If you're interested in the non-singular ill-posed version, you can replace the zeros with 1E-40.
Also, it looks like for the batched case we're just calling to cusolver/rocsolver, so ultimately this might just point to numerical instability inside those libraries. I think that the only thing we could do is try always using the XLA implementation (which does seem to support batching). @hawkinsp given the numerical instability, should we keep using cusolver?
I agree that the cusolver path with fp32 leads to numerical instability. If I run the same example with fp64, it leads to a different result than with fp32:
import jax.numpy as jnp
from jax.config import config; config.update("jax_enable_x64", True)
x_f32 = jnp.array([[[1, 2, 3],
[4, 5, 6],
[7E-41, 8E-41, 9E-41]],
[[1, 2, 3],
[4, 5, 6],
[7E-41, 8E-41, 9E-41]
]], dtype=jnp.float32)
print(jnp.linalg.inv(x_f32)[0])
# [[ -3568.167 7136.166 nan]
# [ 7134.3335 -14271.333 inf]
# [ -3566.5 7135.5 -inf]]
x_f64 = jnp.array([[[1, 2, 3],
[4, 5, 6],
[7E-41, 8E-41, 9E-41]],
[[1, 2, 3],
[4, 5, 6],
[7E-41, 8E-41, 9E-41]
]], dtype=jnp.float64)
print(jnp.linalg.inv(x_f64)[0])
# [[-1.56927543e+16 3.13855087e+16 -1.56927543e+57]
# [ 3.13855087e+16 -6.27710174e+16 3.13855087e+57]
# [-1.56927543e+16 3.13855087e+16 -1.56927543e+57]]
Just to add some context: IMHO, this bug has two distinct problems.
Roughly speaking, for a matrix A with condition number k, the inverse can only be computed up to a relative accuracy of about k*u:
||real_inverse(A) - computed_inverse(A)|| < k u ||real_inverse(A)||
where u is the machine accuracy (~1e-8 in FP32; ~1e-16 in FP64).
The actual accuracy will depend on the algorithm; row-pivoted LU is not provably stable, but in practice it almost always is, so the above should be true, especially for small sizes [see 3.4.5 in Golub & Van Loan 4th edition].
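(As a rough illustration of that bound, a minimal sketch, assuming jnp.linalg.cond is acceptable for estimating the condition number; the exact value is only indicative:)
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
eps = 1e-6
X = random.uniform(key, (3, 3))
X = X.at[-1].mul(eps)          # scale the last row, as the test does
print(jnp.linalg.cond(X))      # roughly O(1/eps) for a generic random matrix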
Now, consider the matrices from https://github.com/google/jax/issues/7054#issuecomment-867007844, which have a condition number of k = O(1/eps) ~= 1e40.
1. The matrices are, effectively, numerically singular. In FP32, any matrix with a condition number > 1e8 is effectively singular; in FP64, any matrix with a condition number > 1e16 is effectively singular. This means there is no right or wrong answer: the problem is numerically ill-posed and "any" output is incorrect.
2. That being said, the code paths for batched & non-batched matrices do give different results. In my case (JAX 0.3.15+cuda11.cudnn82):
import jax.numpy as jnp
x = jnp.array([[[1, 2, 3],
[4, 5, 6],
[1e-40, 1e-40, 1e-40]]])
print("Batched: ", jnp.linalg.inv(x)[0])
print("Non batched: ", jnp.linalg.inv(x[0]))
GPU:
$ python3 bug_7024_2.py
Batched: [[ nan nan nan]
[ inf -inf inf]
[-inf inf -inf]]
Non batched: [[ nan nan nan]
[ inf -inf inf]
[-inf inf -inf]]
CPU:
$ CUDA_VISIBLE_DEVICES="" python3 bug_7024_2.py
2022-08-22 14:53:42.760111: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Batched: [[-1.6666666 0.6666666 1. ]
[ 1.3333334 -0.33333334 -2. ]
[ 0. 0. 1. ]]
Non batched: [[ nan nan nan]
[ nan nan -inf]
[ nan nan inf]]
Considering (1) and (2) above, the correct thing to do in terms of implementation and testing depends on what guarantees JAX is expected to give.
- If JAX is expected to give the same output for batched and non-batched, then yes there is a bug. The only way to guarantee the exact same result for batched and non-batched, even for singular or near-singular matrices, is to run the exact same algorithm. Any deviation can lead to arbitrarily large changes in the output. This is not a bug, this is a fundamental property of the mathematical problem being solved.
- If JAX is expected to give "any" sensible result for batched and non-batched, then eps should probably be a small multiple of the machine accuracy u (maybe 1e-6 in FP32 or 1e-14 in FP64). Then, we could check that ||Output_batched - Output_non_batched|| < (1/eps) * u * ||Output_non_batched|| (see the sketch below).
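(A hypothetical sketch of that check; the helper name and the eps/u values below are illustrative, not the actual test:)
import jax.numpy as jnp
from jax import random
def batched_matches_unbatched(key, N, eps, u):
    # Build the same near-singular input as the test, batched and unbatched.
    X = random.uniform(key, (2, N, N))
    X = X.at[:, -1].mul(eps)
    inv_batched = jnp.linalg.inv(X)[0]
    inv_unbatched = jnp.linalg.inv(X[0])
    # Proposed tolerance: ||batched - unbatched|| < (1/eps) * u * ||unbatched||
    lhs = jnp.linalg.norm(inv_batched - inv_unbatched)
    rhs = (1.0 / eps) * u * jnp.linalg.norm(inv_unbatched)
    return bool(lhs < rhs)
key = random.PRNGKey(1701)
print(batched_matches_unbatched(key, 5, eps=1e-6, u=1e-8))  # FP32-ish values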
I haven't had a chance to look at this bug in detail yet, but that sounds right.
It is certainly the case that we expect (2), not (1). Batching won't preserve numerics exactly on pretty much any hardware: the way we compute, say, a matrix-vector multiplication is usually with a different order of operations than, say, a matrix-matrix multiplication. In the presence of things like TF32 math, they might have very different numerics. So we can only really sensibly say (2), not (1).
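(A hedged illustration of that point; whether the difference below is exactly zero depends on the backend, dtype, and matmul settings such as TF32:)
import jax.numpy as jnp
from jax import random
key1, key2 = random.split(random.PRNGKey(0))
A = random.normal(key1, (256, 256), dtype=jnp.float32)
V = random.normal(key2, (256, 8), dtype=jnp.float32)
# A matrix-vector product vs. the same column taken from a matrix-matrix product;
# these may use different kernels / orders of operations and need not match bitwise.
mv = A @ V[:, 0]
mm_col = (A @ V)[:, 0]
print(jnp.max(jnp.abs(mv - mm_col)))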
I agree that the problem is with the test.