
MultivariateNormalFullCovariance gives false log_prob with JAX backend, GPU and fp64

Open dkn16 opened this issue 3 years ago • 0 comments

Code

import os
# Uncomment to use the platform allocator (this changes the symptom from NaNs to a CUDA error).
#os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
import jax
# Enable fp64 before any arrays are created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Changing the first dimension to 4096 produces NaNs on GPU (see Problem).
Y = jnp.ones((4032, 258), dtype=jnp.float64)
distribution = tfd.MultivariateNormalFullCovariance(
    loc=jnp.zeros(Y.shape[1], dtype=jnp.float64),
    covariance_matrix=jnp.eye(Y.shape[1], dtype=jnp.float64))
distribution.log_prob(Y)

Problem

When I enable fp64 for JAX on GPU, I can only compute a correct log_prob for Y smaller than (4032, 258). For example, a Y of shape (4096, 258) gives an incorrect result:

DeviceArray([-366.08614157, -366.08614157, -366.08614157, ...,
                       nan,           nan,           nan], dtype=float64)
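
The non-NaN entries can be sanity-checked by hand. For a zero-mean, identity-covariance normal in d dimensions, the log density at a point x is -(d/2) log(2π) - ||x||²/2, so an all-ones row with d = 258 should come out at about -366.086, which matches the leading values above. A minimal sketch in plain Python (the helper name `expected_log_prob` is made up for this illustration):

```python
import math

def expected_log_prob(dim, value=1.0):
    """Closed-form log density of N(0, I_dim) at a point whose entries all equal `value`."""
    normalizer = -0.5 * dim * math.log(2.0 * math.pi)
    quadratic = -0.5 * dim * value ** 2
    return normalizer + quadratic

# Every row of Y is all ones with 258 columns, so every entry of log_prob should equal this.
print(expected_log_prob(258))  # approximately -366.086
```

So the leading values look right, and only the trailing NaN entries are wrong.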

The nan entries are clearly wrong. On CPU, however, everything works fine. I suspect this is the same error as in #1666.
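
Since smaller inputs evaluate correctly, one possible stopgap is to split Y along its first axis and evaluate each chunk separately. This is only a sketch: the helper `log_prob_in_chunks` and its default chunk size of 1024 are hypothetical, it has not been tested on the affected GPU setup, and it does not address the underlying bug.

```python
def log_prob_in_chunks(log_prob_fn, y, chunk_size=1024):
    """Apply log_prob_fn to consecutive row blocks of y and return the list of per-block results.

    chunk_size=1024 is an arbitrary choice for illustration, not a known-safe limit.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    parts = []
    for start in range(0, len(y), chunk_size):
        # Each slice has at most chunk_size rows; the last one may be shorter.
        parts.append(log_prob_fn(y[start:start + chunk_size]))
    return parts
```

With the reproduction code above, the call would look like `jnp.concatenate(log_prob_in_chunks(distribution.log_prob, Y))`.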

When I uncomment the os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform' line, I get an error instead:

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-1-6651013d38f8> in <module>
     12 distribution = tfd.MultivariateNormalFullCovariance(loc = jnp.zeros((Y.shape[1])),covariance_matrix = jnp.eye(Y.shape[1],dtype = jnp.float64))
---> 13 distribution.log_prob(Y)
     14 #jnp.where(log_prob<0)

59 frames
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to complete all kernels launched on stream 0x89279a0: Could not synchronize CUDA stream: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

XlaRuntimeError                           Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/jax/_src/scipy/linalg.py in solve_triangular(***failed resolving arguments***)
    404                      debug: Any = None, check_finite: bool = True) -> Array:
    405   del overwrite_b, debug, check_finite  # unused
--> 406   return _solve_triangular(a, b, trans, lower, unit_diagonal)
    407 
    408 

XlaRuntimeError: INTERNAL: Failed to complete all kernels launched on stream 0x89279a0: Could not synchronize CUDA stream: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
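
The trace ends in `jax.scipy.linalg.solve_triangular`, so one way to narrow this down is to call that function directly, without TFP. The sketch below is a guess at the relevant computation, not a confirmed reproduction: the trace does not show the shapes TFP actually passes, so a clean result here would not rule out a problem in TFP's (possibly batched) call. The function name `triangular_solve_has_nan` is made up for this illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)  # same fp64 setting as in the report

def triangular_solve_has_nan(n_rows=4096, dim=258):
    """Solve L z = Y^T for an identity covariance and report whether z contains NaN."""
    y = jnp.ones((n_rows, dim), dtype=jnp.float64)
    chol = jnp.linalg.cholesky(jnp.eye(dim, dtype=jnp.float64))
    # Assumption: a single (dim, dim) factor applied to all rows at once.
    z = solve_triangular(chol, y.T, lower=True)
    return bool(jnp.isnan(z).any())

print(jax.devices())               # shows whether the GPU backend is in use
print(triangular_solve_has_nan())  # True would point at the triangular solve itself
```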

Material

A Colab notebook that reproduces the error can be found here.

dkn16, Jan 06 '23 07:01