`jax.lax.random_gamma_grad` has very slow convergence for large inputs
Description
I am training a variational inference model that uses jax.random.dirichlet to form a Monte Carlo estimate of the ELBO (evidence lower bound). I found that training on GPU sometimes gets stuck. It turns out that specific values of alpha and shape cause jax.random.dirichlet to hang when it is part of a jax.grad-ed function.
I've managed to track down a reproducible code snippet:
import jax
import jax.numpy as jnp

alpha = jnp.array(
    [
        [1.0031196e+00, 1.0657588e+00, 1.0015138e+00, 9.6887946e-01, 1.0072964e+00,
         9.9730104e-01, 9.8444152e-01, 1.0215390e+00, 9.9271071e-01, 9.6725780e-01,
         1.0160147e+00, 9.9990100e-01],
        [1.0190175e+00, 1.0119619e+00, 1.0658907e+00, 1.0210572e+00, 9.9823701e-01,
         1.0059961e+00, 1.0150868e+00, 9.1839767e-01, 9.7111833e-01, 9.5713389e-01,
         1.0254700e+00, 9.3788880e-01],
        [9.6162099e-01, 9.9599779e-01, 1.0572764e+00, 1.0106628e+00, 1.0098387e+00,
         1.0022438e+00, 9.7249472e-01, 9.9937338e-01, 1.0402182e+00, 1.0575043e+00,
         1.0217358e+00, 1.0250369e+00],
        [1.0484207e+00, 9.8520476e-01, 1.0237277e+00, 1.0169791e+00, 1.0115385e+00,
         1.0200191e+00, 1.0063468e+00, 1.0190374e+00, 1.0228696e+00, 1.0135293e+00,
         1.0109347e+00, 9.8110938e-01],
        [9.7286355e-01, 9.8506188e-01, 9.8601210e-01, 1.0110497e+00, 9.9962467e-01,
         1.0079705e+00, 1.0254623e+00, 1.0033917e+00, 1.0134741e+00, 9.3710870e-01,
         1.0039001e+00, 1.0229504e+00],
        [2.0145073e+03, 2.0279155e+03, 2.0245532e+03, 2.0301467e+03, 2.0233922e+03,
         2.0119932e+03, 2.0074518e+03, 2.0082502e+03, 2.0344823e+03, 2.0195020e+03,
         2.7708268e+07, 2.0147581e+03]
    ]
)

def loss(params):
    key = jnp.array([2547407540, 2718371875], dtype=jnp.uint32)
    jax.random.dirichlet(key, params, shape=(64, 6))
    return 1.0

dloss_dparams = jax.grad(loss)
dloss_dparams(alpha)  # <= This statement hangs on GPU (but not CPU).
A few observations:
- I only encounter the problem on GPU (system details below). The function works fine on CPU.
- Setting export JAX_ENABLE_X64=True fixes the problem (a sketch of both workarounds follows this list).
- With slightly different alpha and shape values the computation is slow, but it no longer hangs.
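For completeness, here is a minimal sketch of the two workarounds above. It assumes the repro snippet from earlier has already been run (so dloss_dparams and alpha are in scope); enabling x64 through jax.config is equivalent to the environment variable, and jax.default_device is just one way of forcing CPU execution.

import jax

# Workaround 1: run in 64-bit mode, equivalent to `export JAX_ENABLE_X64=True`.
# The flag has to be set before any arrays are created.
jax.config.update("jax_enable_x64", True)

# Workaround 2: pin the computation to CPU, where the call completes normally.
with jax.default_device(jax.devices("cpu")[0]):
    dloss_dparams(alpha)

Either change on its own is enough to avoid the hang in my setup.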
I suspect that some numerical under/overflow error in the gradient of the gamma distribution prevents a series expansion from terminating.
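If it helps narrow things down, the primitive from the title can also be probed in isolation. This is only a guess at a minimal trigger: I am assuming the slow convergence sits in jax.lax.random_gamma_grad evaluated at the outlier concentration (the 2.77e+07 entry of alpha), and the sample value below is chosen arbitrarily.

import jax
import jax.numpy as jnp

# Derivative of a Gamma(a, 1) sample with respect to the concentration a,
# evaluated at the large outlier value from alpha above (assumed trigger).
a = jnp.float32(2.7708268e+07)
x = a  # a sample near the mean of Gamma(a, 1); arbitrary choice
print(jax.lax.random_gamma_grad(a, x))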
Let me know if there is anything else I can do to help clarify the problem.
System info (python version, jaxlib version, accelerator, etc.)
I've run the code snippet in an Apptainer container based on the ghcr.io/nvidia/jax Docker image.
>>> import jax; jax.print_environment_info()
jax: 0.4.29.dev20240522
jaxlib: 0.4.29.dev20240514
numpy: 1.26.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='nb-node-b01', release='3.10.0-1160.108.1.el7.x86_64', version='#1 SMP Thu Jan 25 16:17:31 UTC 2024', machine='x86_64')
$ nvidia-smi
Wed May 22 08:46:42 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A40 On | 00000000:00:0F.0 Off | 0 |
| 0% 34C P0 28W / 300W | 273MiB / 46068MiB | 2% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 14470 C python3 262MiB |
+---------------------------------------------------------------------------------------+