[Bug] Pivoted Cholesky jit error when using the JAX backend
Hi,
I'm attempting to use the JAX backend (hopefully this is the correct way to call it) to compute pivoted Cholesky decompositions of PSD matrices, but I'm running into errors in the substrate code. Below is a reproducible example:
import jax
from jax.random import PRNGKey
from tensorflow_probability.substrates.jax.math.linalg import pivoted_cholesky
key = PRNGKey(seed=21)
A = jax.random.normal(key, (50, 100))
mat = A.T @ A # random 100 x 100 PSD matrix
pivoted_cholesky(mat, 20) # fails
First, it produces the following warning:
/jax/_src/numpy/lax_numpy.py:3584: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax._check_user_dtype_supported(dtype, "array")
before producing the following error (with accompanying stack trace):
~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/linalg.py in pivoted_cholesky(matrix, max_rank, diag_rtol, name)
390 perm = tf.broadcast_to(
391 ps.range(matrix_shape[-1]), matrix_shape[:-1])
--> 392 _, pchol, _, _ = tf.while_loop(
393 cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
394 pchol = tf.linalg.matrix_transpose(pchol)
~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py in wrap(***failed resolving arguments***)
60 def wrap(wrapped, instance, args, kwargs):
61 del instance, wrapped
---> 62 return new_fn(*args, **kwargs)
63 return wrap(original_fn) # pylint: disable=no-value-for-parameter
64
~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py in _while_loop_jax(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
88 def override_cond_fn(args):
89 return cond(*args)
---> 90 return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
91 elif back_prop:
92 def override_body_fn(args, _):
[... skipping hidden 13 frame]
~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py in override_body_fn(args)
85 if maximum_iterations is None:
86 def override_body_fn(args):
---> 87 return pack_body(body(*args))
88 def override_cond_fn(args):
89 return cond(*args)
~/miniconda3/envs/jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/linalg.py in body(m, pchol, perm, matrix_diag)
343 # Find the maximal position of the (remaining) permuted diagonal.
344 # Steps 1, 2 above.
--> 345 permuted_diag = batch_gather(matrix_diag, perm[..., m:])
346 maxi = tf.argmax(
347 permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
[... skipping hidden 1 frame]
~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
5702 arr = asarray(arr)
5703 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 5704 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
5705 unique_indices, mode, fill_value)
5706
~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
5711 unique_indices, mode, fill_value):
5712 idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 5713 indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
5714 y = arr
5715
~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _index_to_gather(x_shape, idx, normalize_indices)
5954 "dynamic_update_slice (JAX does not support dynamically sized "
5955 "arrays within JIT compiled functions).")
-> 5956 raise IndexError(msg)
5957 if not core.is_constant_dim(x_shape[x_axis]):
5958 msg = ("Cannot use NumPy slice indexing on an array dimension whose "
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
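For reference, the root cause reproduces without TFP. Here is a minimal sketch of my own (not TFP code) of the JAX restriction the error describes: under jit, a slice whose start is a traced value would have a dynamically sized output, which JAX forbids; lax.dynamic_slice only helps when the slice size is static:
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

@jax.jit
def tail(m):
    # Fails under jit: the start is traced, so the output shape is dynamic.
    return x[m:]

@jax.jit
def window(m):
    # Works under jit: only the start is dynamic; the size (4) is static.
    return lax.dynamic_slice(x, (m,), (4,))

window(3)  # Array([3, 4, 5, 6], dtype=int32)
# tail(3)  # raises the same IndexError as above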
Is there a simple way to fix this? I know a good workaround is to not use the JAX backend explicitly, but then it produces a TensorFlow tensor that I have to pass back into JAX:
import jax.numpy as jnp
import tensorflow_probability as tfp

jnp.array(tfp.math.pivoted_cholesky(mat, 20))  # correctly runs, converted back to a JAX array
System information: jax 0.2.26, tfp 0.15.0
This is going to be tricky to fix, since the algorithm uses dynamic shapes inside a while loop, which JAX's jit doesn't allow. We could either rewrite the algorithm to use masking (so all shapes stay static) or unroll the while loop. The latter would be quite inefficient for matrices as large as yours, while the former would require considerably more complex code.
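For concreteness, here is a rough sketch (unbatched, with an illustrative helper name; not TFP's actual implementation) of what the masking version of the failing step could look like: instead of slicing off the first m entries of the permuted diagonal (which changes the output shape), gather the full diagonal and mask the already-pivoted entries out before the argmax, so every intermediate keeps a static shape under jit:
import jax
import jax.numpy as jnp

# Illustrative sketch only (unbatched); not TFP's actual code.
def masked_argmax(matrix_diag, perm, m):
    # Gather the full permuted diagonal; the shape never changes.
    permuted_diag = matrix_diag[perm]
    # Mask out the first m (already-pivoted) entries instead of slicing them
    # off, so the argmax input keeps a static shape under jit.
    active = jnp.arange(perm.shape[-1]) >= m
    masked = jnp.where(active, permuted_diag, -jnp.inf)
    # The index is into the full array, so no +m offset is needed afterwards.
    return jnp.argmax(masked)

jax.jit(masked_argmax)(jnp.array([3., 9., 5., 7.]), jnp.arange(4), 2)  # -> 3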