`jax.lax.random_gamma_grad` has very slow convergence for large inputs
Description
I am training a variational inference model that uses `jax.random.dirichlet` to form a Monte Carlo estimate of the ELBO (evidence lower bound). I found that on GPU, training sometimes gets stuck. It turns out that specific values of `alpha` and `shape` cause `jax.random.dirichlet` to hang when it is part of a function transformed by `jax.grad`.
I've managed to track down a reproducible code snippet:
import jax
import jax.numpy as jnp
alpha = jnp.array(
[
[1.0031196e+00, 1.0657588e+00, 1.0015138e+00, 9.6887946e-01, 1.0072964e+00,
9.9730104e-01, 9.8444152e-01, 1.0215390e+00, 9.9271071e-01, 9.6725780e-01,
1.0160147e+00, 9.9990100e-01],
[1.0190175e+00, 1.0119619e+00, 1.0658907e+00, 1.0210572e+00, 9.9823701e-01,
1.0059961e+00, 1.0150868e+00, 9.1839767e-01, 9.7111833e-01, 9.5713389e-01,
1.0254700e+00, 9.3788880e-01],
[9.6162099e-01, 9.9599779e-01, 1.0572764e+00, 1.0106628e+00, 1.0098387e+00,
1.0022438e+00, 9.7249472e-01, 9.9937338e-01, 1.0402182e+00, 1.0575043e+00,
1.0217358e+00, 1.0250369e+00],
[1.0484207e+00, 9.8520476e-01, 1.0237277e+00, 1.0169791e+00, 1.0115385e+00,
1.0200191e+00, 1.0063468e+00, 1.0190374e+00, 1.0228696e+00, 1.0135293e+00,
1.0109347e+00, 9.8110938e-01],
[9.7286355e-01, 9.8506188e-01, 9.8601210e-01, 1.0110497e+00, 9.9962467e-01,
1.0079705e+00, 1.0254623e+00, 1.0033917e+00, 1.0134741e+00, 9.3710870e-01,
1.0039001e+00, 1.0229504e+00],
[2.0145073e+03, 2.0279155e+03, 2.0245532e+03, 2.0301467e+03, 2.0233922e+03,
2.0119932e+03, 2.0074518e+03, 2.0082502e+03, 2.0344823e+03, 2.0195020e+03,
2.7708268e+07, 2.0147581e+03]
]
)
def loss(params):
    key = jnp.array([2547407540, 2718371875], dtype=jnp.uint32)
    jax.random.dirichlet(key, params, shape=(64, 6))
    return 1.0

dloss_dparams = jax.grad(loss)
dloss_dparams(alpha)  # <= This statement hangs on GPU (but not CPU).
A few observations:
- I only encounter the problem on GPU (system details below). The function works fine on CPU.
- Setting `export JAX_ENABLE_X64=True` fixes the problem.
- Slightly different `alpha` and `shape` values are slow, but no longer hang.
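For completeness, the same 64-bit workaround can be applied from Python rather than via the environment variable, using JAX's documented config flag (set it at program startup, before creating any arrays):

```python
import jax
import jax.numpy as jnp

# Equivalent to the JAX_ENABLE_X64=True environment variable; apply at
# startup, before building any arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.array(1.0)
assert x.dtype == jnp.float64  # floating-point types now default to 64-bit
```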
I suspect that some numerical under/overflow error in the gradient of the gamma distribution prevents a series expansion from terminating.
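For context on why a Dirichlet gradient involves the gamma distribution at all: Dirichlet samples are standardly constructed by normalizing independent Gamma draws, so a reparameterized gradient of a Dirichlet sample differentiates through the gamma sampler. A NumPy sketch of this standard construction (not JAX's exact code path):

```python
import numpy as np

# Standard construction: if x_i ~ Gamma(alpha_i, 1) independently,
# then x / sum(x) ~ Dirichlet(alpha). Differentiating a Dirichlet
# sample therefore goes through the gamma sampler's gradient.
rng = np.random.default_rng(0)
alpha = np.array([2.0, 3.0, 5.0])
g = rng.gamma(shape=alpha)  # one gamma draw per component
d = g / g.sum()             # a Dirichlet(alpha) sample

assert np.isclose(d.sum(), 1.0)  # lies on the simplex
assert (d > 0).all()
```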
Let me know if there is anything else I can do to help clarify the problem.
System info (python version, jaxlib version, accelerator, etc.)
I've run the code snippet in an apptainer container that is based on the ghcr.io/nvidia/jax Docker image.
>>> import jax; jax.print_environment_info()
jax: 0.4.29.dev20240522
jaxlib: 0.4.29.dev20240514
numpy: 1.26.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='nb-node-b01', release='3.10.0-1160.108.1.el7.x86_64', version='#1 SMP Thu Jan 25 16:17:31 UTC 2024', machine='x86_64')
$ nvidia-smi
Wed May 22 08:46:42 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A40 On | 00000000:00:0F.0 Off | 0 |
| 0% 34C P0 28W / 300W | 273MiB / 46068MiB | 2% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 14470 C python3 262MiB |
+---------------------------------------------------------------------------------------+
Thanks for the report. I was able to reduce the problem to this:
import jax
jax.lax.random_gamma_grad(27708268.0, 27708266.0)
These values are being generated and passed to gamma_grad in your code, and these are the inputs that cause a hang on GPU but not CPU.
If you look at the source of `random_gamma_grad_impl`, it's implemented in terms of a `while_loop`, and so something about these values is leading to non-convergence of the while loop on GPU.
Investigating further, it looks like the non-convergence occurs whenever both inputs are above 2 ** 24, which is (probably not coincidentally) the threshold beyond which float32 can no longer represent every integer, since its mantissa carries 24 significant bits.
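The precision cliff is easy to reproduce outside JAX (a hypothetical illustration of the failure mode, not the actual `random_gamma_grad_impl`): above 2 ** 24, adding 1 to a float32 value no longer changes it, so a loop that advances a float32 quantity by small increments toward a target beyond that threshold can never make progress:

```python
import struct

def to_f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

# 2**24 = 16777216 is the last point where float32 represents every integer:
assert to_f32(2**24 - 1) + 1 == 2**24  # still exact below the threshold
assert to_f32(2**24 + 1) == 2**24      # 16777217 rounds back to 16777216

def count_to(n, max_iters=100):
    """Advance a float32 counter toward n; return None if it stalls."""
    k = to_f32(2**24 - 5)
    for _ in range(max_iters):
        if k >= n:
            return k
        k = to_f32(k + 1)  # above 2**24 this leaves k unchanged
    return None

assert count_to(2**24) == 2**24      # target reachable
assert count_to(2**24 + 10) is None  # stalls at the float32 precision limit
```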
Nice catch!