MultivariateNormalFullCovariance gives incorrect (nan) log_prob with JAX backend, GPU and fp64
Code
import os
# Uncomment to switch to the platform allocator (see the error further down).
#os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
import jax
# Enable float64 before any arrays are created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

Y = jnp.ones((4032, 258), dtype=jnp.float64)
# Standard normal in 258 dimensions: zero mean, identity covariance.
distribution = tfd.MultivariateNormalFullCovariance(
    loc=jnp.zeros(Y.shape[1], dtype=jnp.float64),
    covariance_matrix=jnp.eye(Y.shape[1], dtype=jnp.float64))
distribution.log_prob(Y)
Problem
When I enable fp64 for JAX on a GPU, I can only compute the log_prob of Y for sizes smaller than (4032,258). For example, a size of (4096,258) gives an incorrect result:
DeviceArray([-366.08614157, -366.08614157, -366.08614157, ...,
nan, nan, nan], dtype=float64)
The nan values are clearly incorrect. On CPU, however, everything works fine. I suppose this is the error reported in #1666.
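As a sanity check on the finite entries, the expected value can be worked out by hand: for a zero mean and identity covariance, the log-density at a point y of dimension d is -0.5 * (d * log(2π) + ||y||²). The sketch below is plain Python and does not touch JAX or TFP; the helper names `standard_normal_log_prob` and `first_nan_index` are made up for illustration.

```python
import math

def standard_normal_log_prob(y):
    """Log-density of N(0, I) at the point y (a sequence of floats)."""
    d = len(y)
    return -0.5 * (d * math.log(2.0 * math.pi) + sum(v * v for v in y))

def first_nan_index(values):
    """Index of the first nan in values, or None if there is none."""
    for i, v in enumerate(values):
        if math.isnan(v):
            return i
    return None

# One row of Y: 258 ones.
expected = standard_normal_log_prob([1.0] * 258)
print(expected)  # approximately -366.0861, matching the finite entries above
```

Passing the values returned by log_prob to `first_nan_index` would show at which row the nan entries start, assuming the result can be iterated as Python floats.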
When I uncommented the line os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform', I got an error:
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-1-6651013d38f8> in <module>
12 distribution = tfd.MultivariateNormalFullCovariance(loc = jnp.zeros((Y.shape[1])),covariance_matrix = jnp.eye(Y.shape[1],dtype = jnp.float64))
---> 13 distribution.log_prob(Y)
14 #jnp.where(log_prob<0)
59 frames
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to complete all kernels launched on stream 0x89279a0: Could not synchronize CUDA stream: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
XlaRuntimeError Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/jax/_src/scipy/linalg.py in solve_triangular(***failed resolving arguments***)
404 debug: Any = None, check_finite: bool = True) -> Array:
405 del overwrite_b, debug, check_finite # unused
--> 406 return _solve_triangular(a, b, trans, lower, unit_diagonal)
407
408
XlaRuntimeError: INTERNAL: Failed to complete all kernels launched on stream 0x89279a0: Could not synchronize CUDA stream: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Material
A colab can be found here which reproduces the error.