SVD does not propagate NaNs for batch sizes >= 2

Open amacati opened this issue 1 month ago • 1 comments

Description

Running an SVD on a matrix full of NaNs is expected to return NaNs in U, S and Vt, but only S consistently propagates them. For an input of shape (2, 3, 3) full of NaNs, U unexpectedly becomes a batch of two eye(3) matrices, and Vt becomes a batch of two -eye(3) matrices.

import jax.numpy as jnp

print(jnp.linalg.svd(jnp.full((1, 3, 3), jnp.nan)))  # All NaN as expected
print(jnp.linalg.svd(jnp.full((2, 3, 3), jnp.nan)))  # U and Vt contain eye and -eye matrices, only S is NaN
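
The printed arrays can be hard to compare by eye, so a small check along the following lines may help when testing other batch sizes or backends. This is only a sketch: `nan_report` is a hypothetical helper name, not part of JAX, and it assumes square (n, n) inputs filled entirely with NaN.

```python
import jax.numpy as jnp


def nan_report(batch_size, n=3):
    """Report, per SVD output, whether every entry is NaN for an all-NaN input."""
    a = jnp.full((batch_size, n, n), jnp.nan)
    u, s, vh = jnp.linalg.svd(a)
    return {
        "U": bool(jnp.isnan(u).all()),
        "S": bool(jnp.isnan(s).all()),
        "Vh": bool(jnp.isnan(vh).all()),
    }


# Compare a single matrix against a batch of two
for batch_size in (1, 2):
    print(batch_size, nan_report(batch_size))
```

If NaNs propagate correctly, all three entries should be True for every batch size; on the setup described here, U and Vh would be expected to report False for batch size 2.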

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.8.1
jaxlib: 0.8.1
numpy:  2.3.5
python: 3.13.5 | packaged by conda-forge | (main, Jun 16 2025, 08:27:50) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='amacati-workstation', release='6.8.0-88-generic', version='#89-Ubuntu SMP PREEMPT_DYNAMIC Sat Oct 11 01:02:46 UTC 2025', machine='x86_64')

$ nvidia-smi
Wed Dec  3 23:16:26 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0  On |                  Off |
| 30%   31C    P2             45W /  450W |   21101MiB /  24564MiB |      9%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            3534      G   /usr/lib/xorg/Xorg                     1133MiB |
|    0   N/A  N/A            3764      G   /usr/bin/gnome-shell                    229MiB |
|    0   N/A  N/A            4620      G   ...rack-uuid=3190708988185955192        432MiB |
|    0   N/A  N/A            5179      G   /opt/Signal/signal-desktop              159MiB |
|    0   N/A  N/A            6047      G   /proc/self/exe                          305MiB |
|    0   N/A  N/A           49945      G   ...slack/216/usr/lib/slack/slack        109MiB |
|    0   N/A  N/A           53401      C   python                                18496MiB |
+-----------------------------------------------------------------------------------------+

amacati, Dec 03 '25 22:12

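import jax
import jax.numpy as jnp

# Force the CPU backend to compare against the GPU behaviour reported above
jax.config.update("jax_platform_name", "cpu")
print(jax.devices())

print(jnp.linalg.svd(jnp.full((1, 3, 3), jnp.nan)))  # batch size 1
print(jnp.linalg.svd(jnp.full((2, 3, 3), jnp.nan)))  # batch size 2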

Output

[CpuDevice(id=0)]
SVDResult(U=Array([[[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]]], dtype=float32, weak_type=True), S=Array([[nan, nan, nan]], dtype=float32, weak_type=True), Vh=Array([[[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]]], dtype=float32, weak_type=True))
SVDResult(U=Array([[[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]],

       [[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]]], dtype=float32, weak_type=True), S=Array([[nan, nan, nan],
       [nan, nan, nan]], dtype=float32, weak_type=True), Vh=Array([[[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]],

       [[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]]], dtype=float32, weak_type=True))

On the CPU, both batch sizes return NaNs in U, S and Vh. The inconsistency is due to the iterative solvers used on the GPU. On the CPU, I assume the SVD is computed with standard direct solvers (LAPACK).
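
Until the GPU path propagates NaNs by itself, one possible workaround is to mask the outputs by hand. The following is a minimal sketch, not a confirmed fix: `svd_nan_safe` is a hypothetical wrapper name, and it assumes that any matrix in the batch containing a non-finite entry (NaN or inf) should yield all-NaN factors.

```python
import jax.numpy as jnp


def svd_nan_safe(a):
    """SVD that forces NaN outputs for batch entries with non-finite inputs."""
    u, s, vh = jnp.linalg.svd(a)
    # True for every matrix in the batch that contains a NaN or inf
    bad = ~jnp.isfinite(a).all(axis=(-2, -1))
    u = jnp.where(bad[..., None, None], jnp.nan, u)
    s = jnp.where(bad[..., None], jnp.nan, s)
    vh = jnp.where(bad[..., None, None], jnp.nan, vh)
    return u, s, vh


# One well-defined matrix and one all-NaN matrix in the same batch
a = jnp.stack([jnp.eye(3), jnp.full((3, 3), jnp.nan)])
u, s, vh = svd_nan_safe(a)
print(s)
```

The mask is computed from the input rather than from S, so it does not rely on any of the solver outputs being NaN.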

SuriyaaMM, Dec 09 '25 07:12