Apple Silicon: error: failed to legalize operation 'mhlo.cholesky'
Description
After building jaxlib as per the instructions and installing jax-metal, I tested an existing model that works fine on CPU (and on GPU on Linux) and got the following error.
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>
The full error message is very long, and is attached here.
I did try a minimal example shown below which also calls the cholesky operator, but I couldn't reproduce the same error. I am more than happy to try another more in-depth test code. Any suggestions?
from jax import jit
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

key = jnr.PRNGKey(0)
A = jnr.normal(key, (100,100))

def calc_cholesky_decomp(test_matrix):
    psd_test_matrix = test_matrix @ test_matrix.T
    col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
    return col_decomp

calc_cholesky_decomp(A)
jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
jitted_calc_cholesky_decomp(A)
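Until the op is lowered on Metal, one possible stopgap is to run just the factorization on the host CPU. The sketch below is not from the original report: the helper name `cholesky_on_cpu` is hypothetical, it assumes `jax.devices("cpu")` is available alongside the Metal plugin, and it only helps when the Cholesky call can be pulled out of the jitted model code.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

# Assumption: a CPU backend is registered next to the accelerator backend.
cpu = jax.devices("cpu")[0]

def cholesky_on_cpu(psd_matrix):
    """Hypothetical helper: factorize on the host CPU instead of the accelerator."""
    # Committing the input to the CPU device makes the following op run there.
    host_matrix = jax.device_put(psd_matrix, cpu)
    return jsp.linalg.cholesky(host_matrix, lower=True)

key = jnr.PRNGKey(0)
A = jnr.normal(key, (50, 50))
# A @ A.T is positive semi-definite; the small diagonal term keeps it well conditioned.
psd = A @ A.T + 1e-3 * jnp.eye(50)
chol = cholesky_on_cpu(psd)
```

The result lives on the CPU device, so it would need another `jax.device_put` to move it back if the rest of the model runs on Metal.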
What jax/jaxlib version are you using?
jaxlib 0.4.10 (metal), jax 0.4.11
Which accelerator(s) are you using?
CPU/GPU
Additional system info
Python v3.10.10, Apple M2
NVIDIA GPU info
No response
@shuhand0 @kulinseth
Can confirm that this is still broken in version 0.0.5
Any update on ETA here? I am trying to use Brax on Metal and it wants the cholesky decomp.
@kulinseth
Looking into adding the conversion of the op.
I just wanted to mark that it's still not implemented in version 0.0.6 in case anyone noticed the new release
I'm also eagerly awaiting this
Would love to use multivariate normal distributions, which depend on the Cholesky decomposition. Am eagerly awaiting this.
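For multivariate normal sampling specifically, `jax.random.multivariate_normal` accepts a `method` argument (`'cholesky'` by default, also `'eigh'` and `'svd'`), so the covariance factorization does not have to go through Cholesky. The sketch below only shows the call; whether the `eigh` or `svd` lowering works on jax-metal is an assumption that nobody in this thread has tested.

```python
import jax.numpy as jnp
import jax.random as jnr

key = jnr.PRNGKey(0)
mean = jnp.zeros(3)
# Illustrative covariance matrix (symmetric, positive definite).
cov = jnp.array([[2.0, 0.3, 0.0],
                 [0.3, 1.0, 0.1],
                 [0.0, 0.1, 0.5]])

# method="eigh" factorizes cov with an eigendecomposition instead of Cholesky.
samples = jnr.multivariate_normal(key, mean, cov, shape=(1000,), method="eigh")
```

If `eigh` turns out to be unsupported as well, `method="svd"` is the remaining option to try.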
Still not working in jax-metal v0.0.7
We're approaching the one year mark on this. Any hope that this would be resolved soon?
Is jax-metal open source? I can’t find it but would consider contributing.
As far as I know, it's maintained by people at Apple (@kulinseth). I believe they don't share their code.