`jax.lax.random_gamma_grad` has very slow convergence for large inputs

Open hylkedonker opened this issue 9 months ago • 3 comments

Description

I am training a variational inference model that uses jax.random.dirichlet to make a Monte Carlo estimate of the ELBO (evidence lower bound). I found that on GPU, training sometimes gets stuck. It turns out that specific values of alpha and shape cause jax.random.dirichlet to hang when it is part of a function differentiated with jax.grad. I've managed to track this down to a reproducible code snippet:

import jax
import jax.numpy as jnp

alpha = jnp.array(
[
 [1.0031196e+00, 1.0657588e+00, 1.0015138e+00, 9.6887946e-01, 1.0072964e+00,
  9.9730104e-01, 9.8444152e-01, 1.0215390e+00, 9.9271071e-01, 9.6725780e-01,
  1.0160147e+00, 9.9990100e-01],
 [1.0190175e+00, 1.0119619e+00, 1.0658907e+00, 1.0210572e+00, 9.9823701e-01,
  1.0059961e+00, 1.0150868e+00, 9.1839767e-01, 9.7111833e-01, 9.5713389e-01,
  1.0254700e+00, 9.3788880e-01],
 [9.6162099e-01, 9.9599779e-01, 1.0572764e+00, 1.0106628e+00, 1.0098387e+00,
  1.0022438e+00, 9.7249472e-01, 9.9937338e-01, 1.0402182e+00, 1.0575043e+00,
  1.0217358e+00, 1.0250369e+00],
 [1.0484207e+00, 9.8520476e-01, 1.0237277e+00, 1.0169791e+00, 1.0115385e+00,
  1.0200191e+00, 1.0063468e+00, 1.0190374e+00, 1.0228696e+00, 1.0135293e+00,
  1.0109347e+00, 9.8110938e-01],
 [9.7286355e-01, 9.8506188e-01, 9.8601210e-01, 1.0110497e+00, 9.9962467e-01,
  1.0079705e+00, 1.0254623e+00, 1.0033917e+00, 1.0134741e+00, 9.3710870e-01,
  1.0039001e+00, 1.0229504e+00],
 [2.0145073e+03, 2.0279155e+03, 2.0245532e+03, 2.0301467e+03, 2.0233922e+03,
  2.0119932e+03, 2.0074518e+03, 2.0082502e+03, 2.0344823e+03, 2.0195020e+03,
  2.7708268e+07, 2.0147581e+03]
]
)

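def loss(params):
    # Fixed PRNG key (raw uint32 pair) used to reproduce the hang.
    key = jnp.array([2547407540, 2718371875], dtype=jnp.uint32)
    # The samples are not used in the return value; the call alone,
    # placed inside the differentiated function, reproduces the hang.
    jax.random.dirichlet(key, params, shape=(64, 6))
    return 1.0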

dloss_dparams = jax.grad(loss)
dloss_dparams(alpha)  # <= This statement hangs on GPU (but not CPU).

A few observations:

  • I only encounter the problem on GPU (system details below). The function works fine on CPU.
  • Setting the environment variable JAX_ENABLE_X64=True (export JAX_ENABLE_X64=True) fixes the problem.
  • Slightly different values of alpha and shape are slow, but no longer hang.
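For reference, double precision can also be enabled from within Python rather than through the environment variable. The sketch below only shows how to switch the flag and check the resulting dtype. I have confirmed the fix only with the environment variable, so treating the in-Python setting as equivalent for this bug is an assumption.

```python
import jax

# Enable 64-bit types; this should run before any arrays are created.
# Assumption: this has the same effect on the hang as JAX_ENABLE_X64=True.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Illustrative values spanning the range seen in the reproducer.
alpha = jnp.array([1.0, 2.0e3, 2.7708268e7])
print(alpha.dtype)  # float64 once x64 is enabled, float32 otherwise
```

With the flag set, floating-point arrays default to float64, so the large entries of alpha are no longer handled in single precision.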

I suspect that some numerical under/overflow error in the gradient of the gamma distribution prevents a series expansion from terminating.
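To make the link to the gamma distribution explicit: a Dirichlet sample can be constructed by normalizing independent Gamma(alpha_i, 1) draws, so a gradient with respect to alpha has to pass through the gradient of the gamma sampler. The sketch below illustrates that relationship with small, well-behaved alpha values. The helper name dirichlet_via_gamma is hypothetical, and this is not necessarily how jax.random.dirichlet is implemented internally.

```python
import jax
import jax.numpy as jnp


def dirichlet_via_gamma(key, alpha):
    # Hypothetical helper: draw independent Gamma(alpha_i, 1) samples
    # and normalize them so that they sum to one.
    g = jax.random.gamma(key, alpha)
    return g / jnp.sum(g, axis=-1, keepdims=True)


key = jax.random.PRNGKey(0)
alpha = jnp.array([1.0, 2.0, 3.0])  # small values, chosen for illustration

sample = dirichlet_via_gamma(key, alpha)

# Differentiating with respect to alpha requires the derivative of the
# gamma samples with respect to their shape parameter.
grad_alpha = jax.grad(lambda a: dirichlet_via_gamma(key, a)[0])(alpha)

print(sample, grad_alpha)
```

If the slowdown also shows up when differentiating jax.random.gamma alone with the large alpha entries from the reproducer, that would point at the gamma gradient rather than at the Dirichlet normalization.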

Let me know if there is anything else I can do to help clarify the problem.

System info (python version, jaxlib version, accelerator, etc.)

I've run the code snippet in an apptainer container that is based on the ghcr.io/nvidia/jax Docker image.

>>> import jax; jax.print_environment_info()

jax:    0.4.29.dev20240522
jaxlib: 0.4.29.dev20240514
numpy:  1.26.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='nb-node-b01', release='3.10.0-1160.108.1.el7.x86_64', version='#1 SMP Thu Jan 25 16:17:31 UTC 2024', machine='x86_64')


$ nvidia-smi
Wed May 22 08:46:42 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A40                     On  | 00000000:00:0F.0 Off |                    0 |
|  0%   34C    P0              28W / 300W |    273MiB / 46068MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     14470      C   python3                                     262MiB |
+---------------------------------------------------------------------------------------+

Posted by hylkedonker on May 22 '24 09:05