Apple Silicon: error: failed to legalize operation 'mhlo.triangular_solve'

Open • yu-fz opened this issue 1 year ago • 9 comments

Description

Matrix inversion appears to be broken on jax-metal. Apologies in advance if this is not the right place to report the issue.

Repro:

import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'METAL')

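# Note: this matrix is singular (det = 0), so it has no true inverse.
# The Metal error is raised while compiling inv, before the values matter.
A = jnp.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])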
B = jnp.linalg.inv(A)
print(B)

What jax/jaxlib version are you using?

jax 0.4.11, jaxlib 0.4.10

Which accelerator(s) are you using?

GPU/Metal

Additional system info

macOS 13.5.1, M1 Max

NVIDIA GPU info

No response

yu-fz • Sep 07 '23 09:09
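A quick way to see why matrix inversion runs into a triangular_solve error is to trace jnp.linalg.inv to a jaxpr, which involves no backend compilation. This is a minimal sketch using only the standard jax.make_jaxpr API; as far as I know, inv is implemented as an LU factorization followed by triangular solves, so the primitive named in the error message should show up in the traced program. The matrix is an arbitrary non-singular example.

```python
import jax
import jax.numpy as jnp

# Arbitrary non-singular float32 matrix; the values do not matter for tracing.
A = jnp.array([[4.0, 7.0],
               [2.0, 6.0]])

# Trace jnp.linalg.inv to a jaxpr without compiling it for any backend.
jaxpr_text = str(jax.make_jaxpr(jnp.linalg.inv)(A))

# Check whether the primitive named in the error message is present.
print("triangular_solve" in jaxpr_text)
```

Because make_jaxpr only traces, this should work even where compilation fails, although that has not been checked on jax-metal here.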

Would be a +1 for optics if assignees could 👍 to acknowledge receipt of the task. There's a handful of bugs around core operations with no indication of progress or of what's blocking them. I'm guessing it's Apple dragging their heels; do they not have the firepower to provide a software stack for their own hardware (something only they can do), hardware that's over a year old now? It seems cringe that the resulting UX disappointment should fall unfairly on JAX.

p-i- • Sep 22 '23 21:09

Just to add: this and a plethora of other fundamental JAX operations simply don't work on jax-metal, and there is no insight into when updates will come. At this point, jax-metal is quite useless as far as I have been able to test it. As of jax-metal 0.0.4, I would say simply avoid it and install JAX for CPU.

benjaminvatterj • Dec 12 '23 20:12
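As a middle ground between jax-metal and a CPU-only install, one possible workaround is to run just the failing op on the CPU backend and move the result back. This is a minimal sketch, not an official fix: inv_on_cpu is a hypothetical helper, and it assumes the CPU backend is still reachable through jax.devices("cpu") when jax-metal is the default platform.

```python
import jax
import jax.numpy as jnp


def inv_on_cpu(a):
    """Hypothetical helper: invert `a` on the CPU backend.

    Assumes jax.devices("cpu") is available alongside the default
    (e.g. Metal) backend; the result is moved back to the default device.
    """
    cpu = jax.devices("cpu")[0]
    # Create any intermediate arrays (for example an identity matrix)
    # on the CPU as well, and commit the input to the CPU device.
    with jax.default_device(cpu):
        result = jnp.linalg.inv(jax.device_put(a, cpu))
    # Move the result to the first device of the default backend.
    return jax.device_put(result, jax.devices()[0])


A = jnp.array([[4.0, 7.0],
               [2.0, 6.0]])
B = inv_on_cpu(A)
print(B)
```

Each call may copy data between devices twice, so this is mainly attractive for occasional, small inversions.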

Just FYI, it seems Apple released a new version of jax-metal a few weeks back: version 0.0.5. This basic feature is still broken.

benjaminvatterj • Jan 01 '24 19:01

Same issue here. Any idea when a fix will be released?

crondonm • Feb 16 '24 14:02

Pinging this, and adding that I hope it gets fixed in the next release.

I think the JAX + Apple Silicon combo offers some unique advantages for prototyping models locally (i.e., less compute but more memory than a comparable NVIDIA workstation). I think JAX users will pick up on this pretty quickly once features like this are fixed.

nngabe • Mar 05 '24 22:03

We are looking into adding the conversion for the op ('mhlo.triangular_solve').

shuhand0 • Mar 06 '24 20:03
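For context on what such a conversion has to cover: in the common case of a lower-triangular matrix on the left, triangular_solve solves L x = b by forward substitution. The sketch below shows only that case in plain jax.numpy. forward_substitution is a hypothetical name; it handles a single right-hand-side vector only (no batching, transposition, upper-triangular, or unit-diagonal options), and whether jax-metal can run it as a stopgap has not been tested.

```python
import jax.numpy as jnp


def forward_substitution(L, b):
    """Solve L @ x = b for a lower-triangular L, one row at a time.

    Hypothetical helper covering only the lower-triangular,
    single-vector case; it does not use the triangular_solve primitive.
    """
    n = L.shape[0]
    x = jnp.zeros_like(b)
    for i in range(n):
        # x[i] = (b[i] - sum_{j < i} L[i, j] * x[j]) / L[i, i]
        # For i == 0 the dot product of two empty slices is 0.
        x = x.at[i].set((b[i] - L[i, :i] @ x[:i]) / L[i, i])
    return x


L = jnp.array([[2.0, 0.0, 0.0],
               [1.0, 3.0, 0.0],
               [4.0, 5.0, 6.0]])
b = jnp.array([2.0, 7.0, 32.0])
x = forward_substitution(L, b)
print(x)  # [1. 2. 3.]
```

The Python loop unrolls into n separate updates when traced, so this is only reasonable for small matrices.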

I also hope this gets fixed soon! Do you happen to have any updates, @shuhand0? Thanks!

mvanaltvorst • Apr 24 '24 09:04