SVD does not propagate NaNs for batch sizes > 1
Description
Running an SVD on a matrix full of NaNs is expected to return NaNs in U, S, and Vh, but on GPU only S consistently propagates them. For an input of shape (2, 3, 3) filled with NaNs, U unexpectedly comes back as a batch of two eye(3) matrices and Vh as a batch of two -eye(3) matrices.
import jax.numpy as jnp
print(jnp.linalg.svd(jnp.full((1, 3, 3), jnp.nan))) # All NaN as expected
print(jnp.linalg.svd(jnp.full((2, 3, 3), jnp.nan)))  # U is a batch of eye(3), Vh a batch of -eye(3); only S is NaN
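A quick programmatic way to see which outputs propagate NaNs (a hypothetical helper, not part of the original report; nan_report is my own name):

import jax.numpy as jnp

def nan_report(batch_size):
    # Decompose a batch of 3x3 matrices filled with NaNs and report whether each factor is all-NaN.
    U, S, Vh = jnp.linalg.svd(jnp.full((batch_size, 3, 3), jnp.nan))
    return {name: bool(jnp.isnan(x).all()) for name, x in (("U", U), ("S", S), ("Vh", Vh))}

print(nan_report(1))  # observed on GPU: {'U': True, 'S': True, 'Vh': True}
print(nan_report(2))  # observed on GPU: {'U': False, 'S': True, 'Vh': False}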
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.8.1
jaxlib: 0.8.1
numpy: 2.3.5
python: 3.13.5 | packaged by conda-forge | (main, Jun 16 2025, 08:27:50) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='amacati-workstation', release='6.8.0-88-generic', version='#89-Ubuntu SMP PREEMPT_DYNAMIC Sat Oct 11 01:02:46 UTC 2025', machine='x86_64')
$ nvidia-smi
Wed Dec 3 23:16:26 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 On | Off |
| 30% 31C P2 45W / 450W | 21101MiB / 24564MiB | 9% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 3534 G /usr/lib/xorg/Xorg 1133MiB |
| 0 N/A N/A 3764 G /usr/bin/gnome-shell 229MiB |
| 0 N/A N/A 4620 G ...rack-uuid=3190708988185955192 432MiB |
| 0 N/A N/A 5179 G /opt/Signal/signal-desktop 159MiB |
| 0 N/A N/A 6047 G /proc/self/exe 305MiB |
| 0 N/A N/A 49945 G ...slack/216/usr/lib/slack/slack 109MiB |
| 0 N/A N/A 53401 C python 18496MiB |
+-----------------------------------------------------------------------------------------+
For comparison, the same reproduction on the CPU backend propagates NaNs in U, S, and Vh:

import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")
print(jax.devices())
print(jnp.linalg.svd(jnp.full((1, 3, 3), jnp.nan)))  # All NaN
print(jnp.linalg.svd(jnp.full((2, 3, 3), jnp.nan)))  # All NaN on CPU as well
Output
[CpuDevice(id=0)]
SVDResult(U=Array([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]], dtype=float32, weak_type=True), S=Array([[nan, nan, nan]], dtype=float32, weak_type=True), Vh=Array([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]], dtype=float32, weak_type=True))
SVDResult(U=Array([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],
[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]], dtype=float32, weak_type=True), S=Array([[nan, nan, nan],
[nan, nan, nan]], dtype=float32, weak_type=True), Vh=Array([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],
[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]], dtype=float32, weak_type=True))
This inconsistency is presumably due to the iterative SVD solver used on the GPU. On the CPU I assume the decomposition goes through a standard LAPACK routine, which propagates the NaNs.
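Until this is addressed upstream, a user-side workaround is to mask the factors whenever the corresponding input matrix contains NaNs. This is only a sketch under the assumption that masking is acceptable for the caller; svd_propagate_nan is a name I made up, not a JAX API:

import jax.numpy as jnp

def svd_propagate_nan(a):
    # Wrap jnp.linalg.svd and force NaNs into U and Vh for every batch element whose input has a NaN.
    U, S, Vh = jnp.linalg.svd(a)
    bad = jnp.isnan(a).any(axis=(-2, -1))            # one flag per batch element
    mask = lambda x: jnp.where(bad[..., None, None], jnp.nan, x)
    return mask(U), S, mask(Vh)                      # S already propagates NaNs correctly

With this wrapper, svd_propagate_nan(jnp.full((2, 3, 3), jnp.nan)) returns all-NaN U, S, and Vh on both backends.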