
Apple Silicon: error: failed to legalize operation 'mhlo.cholesky'

Open adam-hartshorne opened this issue 2 years ago • 18 comments

Description

After building jaxlib following the instructions and installing jax-metal, I get the following error when testing an existing model that works fine on CPU (and on GPU under Linux).

```
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>
```

The full error message is very long, so it is attached here.

cholesky_full_error.txt.zip

I tried the minimal example below, which also calls the Cholesky operator, but it does not reproduce the error. I'm happy to try a more in-depth test case. Any suggestions?

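```python
from jax import jit
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy.linalg as jsl

key = jnr.PRNGKey(0)
A = jnr.normal(key, (100, 100))

def calc_cholesky_decomp(test_matrix):
    # A @ A.T is symmetric positive semi-definite (positive definite for a
    # full-rank A), so it has a Cholesky factor.
    psd_test_matrix = test_matrix @ test_matrix.T
    chol_decomp = jsl.cholesky(psd_test_matrix, lower=True)
    return chol_decomp

# Eager (op-by-op) call.
calc_cholesky_decomp(A)

# Same computation compiled as a single program with jit.
jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
jitted_calc_cholesky_decomp(A).block_until_ready()
```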
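When a minimal repro doesn't trigger the error, one way to narrow down which part of a larger model emits the Cholesky op is to trace each piece to a jaxpr and search it. The sketch below is a suggestion, not something from this thread: `model_step` and `uses_cholesky` are hypothetical names, and the check is a coarse text search over the printed jaxpr, which is device-independent and needs no compilation.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy.linalg as jsl


def model_step(x):
    # Hypothetical stand-in for one piece of a larger model (e.g. a GP kernel solve).
    k = x @ x.T + 1e-3 * jnp.eye(x.shape[0])
    return jsl.cholesky(k, lower=True)


def uses_cholesky(fn, *args):
    # Trace fn to a jaxpr without compiling or lowering for any backend, then
    # search the printed text, which also includes nested jit sub-jaxprs.
    # Coarse: any jitted function whose *name* contains "cholesky" also matches.
    jaxpr = jax.make_jaxpr(fn)(*args)
    return "cholesky" in str(jaxpr)


x = jnr.normal(jnr.PRNGKey(0), (50, 50))
print(uses_cholesky(model_step, x))
print(uses_cholesky(lambda y: y @ y.T, x))
```

Running this per sub-function (kernel, likelihood, predictive, and so on) shows which ones pull in the op, without needing a Metal device to do the tracing.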

What jax/jaxlib version are you using?

jaxlib 0.4.10 (metal), jax 0.4.11

Which accelerator(s) are you using?

CPU/GPU

Additional system info

Python v3.10.10, Apple M2

NVIDIA GPU info

No response

adam-hartshorne, Jun 08 '23 18:06

@shuhand0 @kulinseth

hawkinsp, Jun 08 '23 20:06

Can confirm that this is still broken in version 0.0.5

benjaminvatterj, Jan 01 '24 20:01
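Until the op is supported, one possible stopgap (not suggested in the thread) is to run just the factorization on the CPU backend. This sketch assumes a CPU device is registered alongside the Metal plugin so that `jax.devices("cpu")` is non-empty; `cholesky_via_cpu` is a hypothetical helper name. It relies on JAX running an operation on the device where its committed inputs live.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

# Assumption: the CPU backend is registered next to the Metal plugin,
# so jax.devices("cpu") returns at least one device.
CPU = jax.devices("cpu")[0]


def cholesky_via_cpu(k):
    """Hypothetical helper: lower Cholesky factor of k, computed on the CPU."""
    # Committing the input to the CPU device makes the cholesky call run
    # there, because JAX executes on the device its committed inputs live on.
    k_cpu = jax.device_put(k, CPU)
    chol = jsl.cholesky(k_cpu, lower=True)
    # Copy the factor back to the default device for the rest of the model.
    return jax.device_put(chol, jax.devices()[0])


k = jnp.array([[4.0, 2.0], [2.0, 3.0]])
chol = cholesky_via_cpu(k)
```

This only helps outside `jax.jit`: as far as I know, a jitted function is compiled for a single backend as a whole, so the op inside it would still be sent to Metal. It also costs a host round trip per call.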

Any update on an ETA? I'm trying to use Brax on Metal, and it needs the Cholesky decomposition.

@kulinseth

c0g, Mar 01 '24 18:03
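For code that only needs the factor, another possible stopgap is to build the Cholesky factorization from basic operations (matrix-vector products, `sqrt`, dynamic indexing, a `fori_loop`) instead of the dedicated op. The sketch below is a textbook left-looking, column-by-column Cholesky, not anything from jax-metal; whether every primitive it uses legalizes on Metal is an untested assumption, and `cholesky_lower` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cholesky_lower(a):
    """Left-looking Cholesky: returns lower-triangular L with a = L @ L.T."""
    n = a.shape[0]
    rows = jnp.arange(n)

    def body(j, chol):
        # Columns 0..j-1 of chol are final and columns j.. are still zero,
        # so chol @ chol[j] sums L[i, k] * L[j, k] over k < j only.
        dots = chol @ chol[j]
        diag = jnp.sqrt(a[j, j] - dots[j])
        col = (a[:, j] - dots) / diag
        # Keep the strictly-lower part of column j, then set the diagonal entry.
        col = jnp.where(rows > j, col, 0.0)
        col = col.at[j].set(diag)
        return chol.at[:, j].set(col)

    return lax.fori_loop(0, n, body, jnp.zeros_like(a))


b = jax.random.normal(jax.random.PRNGKey(1), (50, 50))
k = b @ b.T + 50.0 * jnp.eye(50)  # well-conditioned SPD test matrix
chol = jax.jit(cholesky_lower)(k)
```

It does O(n^3) work through n matrix-vector products, so it will be slower than a native kernel, and it has no pivoting or checks for matrices that are not positive definite.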

Looking into adding the conversion for this op.

shuhand0, Mar 06 '24 20:03

Just noting that it's still not implemented in version 0.0.6, in case anyone saw the new release.

benjaminvatterj, Mar 13 '24 12:03

I'm also eagerly awaiting this

vhaasteren, Mar 22 '24 07:03

Would love to use multivariate normal distributions, which depend on the Cholesky decomposition. Eagerly awaiting this.

mvanaltvorst, Apr 04 '24 11:04
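To illustrate where the multivariate normal depends on Cholesky: with cov = L Lᵀ, the log-determinant is 2·Σ log diag(L), the quadratic form is ‖L⁻¹(x − mean)‖², and sampling maps z ~ N(0, I) to mean + L z. The sketch below writes both out by hand (`mvn_logpdf` and `mvn_sample` are illustrative names). Separately, `jax.random.multivariate_normal` takes a `method` argument (`"cholesky"` by default, or `"eigh"` / `"svd"`) that avoids the Cholesky op for sampling; whether those decompositions lower on jax-metal is not established here.

```python
import jax.numpy as jnp
import jax.random as jnr
from jax.scipy.linalg import cholesky, solve_triangular


def mvn_logpdf(x, mean, cov):
    # cov = L @ L.T, so log|cov| = 2 * sum(log(diag(L))) and the quadratic
    # form (x - mean)^T cov^{-1} (x - mean) equals ||L^{-1} (x - mean)||^2.
    chol = cholesky(cov, lower=True)
    z = solve_triangular(chol, x - mean, lower=True)
    d = x.shape[-1]
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (d * jnp.log(2.0 * jnp.pi) + log_det + z @ z)


def mvn_sample(key, mean, cov):
    # z ~ N(0, I) pushed through L gives mean + L z ~ N(mean, L @ L.T).
    chol = cholesky(cov, lower=True)
    z = jnr.normal(key, mean.shape)
    return mean + chol @ z


mean = jnp.array([0.0, 1.0])
cov = jnp.diag(jnp.array([1.0, 4.0]))
x = jnp.array([0.5, -1.0])
print(mvn_logpdf(x, mean, cov), mvn_sample(jnr.PRNGKey(0), mean, cov))
```

Both functions call `cholesky`, so on jax-metal they would hit the same legalization error as in the original report.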

Still not working in jax-metal v0.0.7

driesmarzougui, May 06 '24 08:05

We're approaching the one-year mark on this. Any hope that this will be resolved soon?

benjaminvatterj, May 15 '24 13:05

Is jax-metal open source? I can’t find it but would consider contributing.

c0g, May 15 '24 13:05

As far as I know, it's maintained by people at Apple (@kulinseth). I believe they don't share the code.

benjaminvatterj, May 15 '24 13:05