
[Bug] Pivoted Cholesky jit error when using the JAX backend

Open · wjmaddox opened this issue 3 years ago · 1 comment

Hi,

I'm attempting to use the JAX backend (hopefully this is the correct way to call it) to compute pivoted Cholesky decompositions of PSD matrices, but am running into errors in the substrate code. Below is a reproducible example:

import jax

from jax.random import PRNGKey
from tensorflow_probability.substrates.jax.math.linalg import pivoted_cholesky

key = PRNGKey(seed=21)
A = jax.random.normal(key, (50, 100))
mat = A.T @ A  # random 100 x 100 PSD matrix

pivoted_cholesky(mat, 20) # fails

First, it produces the following warning:

/jax/_src/numpy/lax_numpy.py:3584: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax._check_user_dtype_supported(dtype, "array")
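
(As the warning itself suggests, this part is just the standard x64 dtype truncation and can be silenced by enabling 64-bit types before any JAX calls; it appears unrelated to the failure below:)

import jax
jax.config.update("jax_enable_x64", True)  # or set JAX_ENABLE_X64=1 in the shell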

It then fails with the following error (accompanying stack trace included):

~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/linalg.py in pivoted_cholesky(matrix, max_rank, diag_rtol, name)
    390     perm = tf.broadcast_to(
    391         ps.range(matrix_shape[-1]), matrix_shape[:-1])
--> 392     _, pchol, _, _ = tf.while_loop(
    393         cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
    394     pchol = tf.linalg.matrix_transpose(pchol)

~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py in wrap(***failed resolving arguments***)
     60   def wrap(wrapped, instance, args, kwargs):
     61     del instance, wrapped
---> 62     return new_fn(*args, **kwargs)
     63   return wrap(original_fn)  # pylint: disable=no-value-for-parameter
     64 

~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py in _while_loop_jax(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
     88     def override_cond_fn(args):
     89       return cond(*args)
---> 90     return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
     91   elif back_prop:
     92     def override_body_fn(args, _):

    [... skipping hidden 13 frame]

~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py in override_body_fn(args)
     85   if maximum_iterations is None:
     86     def override_body_fn(args):
---> 87       return pack_body(body(*args))
     88     def override_cond_fn(args):
     89       return cond(*args)

~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/linalg.py in body(m, pchol, perm, matrix_diag)
    343       # Find the maximal position of the (remaining) permuted diagonal.
    344       # Steps 1, 2 above.
--> 345       permuted_diag = batch_gather(matrix_diag, perm[..., m:])
    346       maxi = tf.argmax(
    347           permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]

    [... skipping hidden 1 frame]

~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   5702   arr = asarray(arr)
   5703   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 5704   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   5705                  unique_indices, mode, fill_value)
   5706 

~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   5711             unique_indices, mode, fill_value):
   5712   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 5713   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   5714   y = arr
   5715 

~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _index_to_gather(x_shape, idx, normalize_indices)
   5954                  "dynamic_update_slice (JAX does not support dynamically sized "
   5955                  "arrays within JIT compiled functions).")
-> 5956           raise IndexError(msg)
   5957         if not core.is_constant_dim(x_shape[x_axis]):
   5958           msg = ("Cannot use NumPy slice indexing on an array dimension whose "

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
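
For reference, the same failure mode shows up in plain JAX, independent of TFP: NumPy-style slicing with a traced start index fails under tracing, while lax.dynamic_slice works because its slice size is static. A minimal sketch (not TFP code):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

def bad(m):
    return x[m:]  # traced slice start -> raises the IndexError above

def ok(m):
    # Dynamic start position, but a static slice size of 4.
    return lax.dynamic_slice(x, (m,), (4,))

# jax.jit(bad)(3)  # IndexError: slice indices must be static
jax.jit(ok)(3)     # works

The catch for pivoted_cholesky is that perm[..., m:] has a size that depends on the loop variable m, which even lax.dynamic_slice cannot express.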

Is there a simple way to fix this? I know a good workaround is to not use the backend explicitly, but then it produces a TensorFlow output that I have to pass back into JAX:

import jax.numpy as jnp
import tensorflow_probability as tfp

jnp.array(tfp.math.pivoted_cholesky(mat, 20))  # runs correctly
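
One caveat with this workaround: the TF implementation runs eagerly, outside JAX's tracing, so it can't be called from inside jitted code. A minimal sketch of a convenience wrapper (the name pivoted_cholesky_via_tf is made up for illustration):

import jax.numpy as jnp
import tensorflow_probability as tfp

def pivoted_cholesky_via_tf(mat, max_rank):
    # Run the TF backend eagerly, then convert the resulting tf.Tensor
    # back into a JAX array. This cannot appear inside jax.jit, since
    # the TF call executes outside JAX's tracing machinery.
    return jnp.asarray(tfp.math.pivoted_cholesky(mat, max_rank).numpy())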

System information: jax '0.2.26', tfp '0.15.0'

wjmaddox · Jan 08 '22 23:01

This is going to be tricky to fix, since the algorithm uses dynamically sized slices within a while loop, something that JAX's jit doesn't allow. We could either rewrite the algorithm to use masking, or unroll the while loop. The latter would be quite inefficient for matrices as large as yours, while the former would require considerably more complex code.
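
A rough sketch of the masking idea for the failing step (illustrative only, not the actual fix; permuted_diag_argmax is a made-up helper): instead of slicing perm[..., m:], whose size depends on the loop variable m, gather the full permuted diagonal and mask out the first m entries before the argmax, so every shape stays static:

import jax.numpy as jnp

def permuted_diag_argmax(matrix_diag, perm, m):
    # Gather the diagonal in permuted order; full length, so the shape
    # is static regardless of m.
    permuted_diag = jnp.take_along_axis(matrix_diag, perm, axis=-1)
    # Mask the already-pivoted positions (< m) with -inf instead of
    # slicing them off, so argmax sees a statically shaped array.
    positions = jnp.arange(perm.shape[-1])
    masked = jnp.where(positions >= m, permuted_diag, -jnp.inf)
    # Index into the full array; equivalent to m + argmax over the slice.
    return jnp.argmax(masked, axis=-1)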

SiegeLordEx · Jan 10 '22 20:01