jax
Apple Silicon: error: failed to legalize operation 'mhlo.triangular_solve'
Description
Matrix inversion appears to be broken on jax-metal. Apologies in advance if this is not the right place to report the issue.
Repro:
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'METAL')
A = jnp.array([[1., 2., 3.],
               [0., 1., 4.],
               [5., 6., 0.]])  # non-singular, so the inverse exists
B = jnp.linalg.inv(A)
print(B)
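The error names 'mhlo.triangular_solve' even though the repro only calls jnp.linalg.inv. As far as we understand, JAX computes an inverse by LU-factorizing the matrix and then running triangular solves, so a backend that cannot legalize the triangular-solve op fails on inv. The sketch below spells out that decomposition with jax.scipy.linalg.lu and jax.scipy.linalg.solve_triangular. inv_via_lu is a hypothetical helper for illustration, not JAX's own implementation, and on affected jax-metal versions it would presumably hit the same unsupported op.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu, solve_triangular

def inv_via_lu(a):
    """Hypothetical helper: invert `a` via an LU factorization."""
    # Factor a = P @ L @ U (SciPy convention): L is unit lower-triangular,
    # U is upper-triangular, P is a permutation matrix.
    p, l, u = lu(a)
    # a^{-1} = U^{-1} @ L^{-1} @ P^T, computed with two triangular solves.
    y = solve_triangular(l, p.T, lower=True, unit_diagonal=True)  # L @ Y = P^T
    return solve_triangular(u, y, lower=False)                     # U @ X = Y

A = jnp.array([[1., 2., 3.],
               [0., 1., 4.],
               [5., 6., 0.]])
print(inv_via_lu(A))
```

Because the two solve_triangular calls are exactly the triangular solves in question, this is also a compact way to see why an inversion bug surfaces as a triangular_solve legalization error.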
What jax/jaxlib version are you using?
jax 0.4.11, jaxlib 0.4.10
Which accelerator(s) are you using?
GPU/Metal
Additional system info
macOS 13.5.1, M1 Max
NVIDIA GPU info
No response
Just to add: this and a plethora of other fundamental JAX operations simply don't work on jax-metal, and there is no insight into when they will be updated. At this point, jax-metal is quite useless as far as I have been able to test it. As of jax-metal 0.0.4, I would simply avoid it and install JAX for CPU.
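Short of uninstalling jax-metal, one possible middle ground is to run only the failing operation on JAX's CPU backend, which should still be available with the Metal plugin installed (assuming the jax_platforms option has not been used to exclude it). A minimal sketch, assuming jax.devices("cpu") and the jax.default_device context manager; inv_on_cpu is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def inv_on_cpu(a):
    """Hypothetical helper: invert `a` on the host CPU backend."""
    cpu = jax.devices("cpu")[0]
    # Commit the input to the CPU and make the CPU the default device for
    # arrays created inside inv, so the computation compiles for the CPU.
    with jax.default_device(cpu):
        return jnp.linalg.inv(jax.device_put(a, cpu))

A = jnp.array([[1., 2., 3.],
               [0., 1., 4.],
               [5., 6., 0.]])
A_inv = inv_on_cpu(A)
print(A_inv)
```

The result lives on the CPU; jax.device_put(A_inv, jax.devices()[0]) moves it back to the default device if later Metal computations need it. Each round trip costs a host/device copy, so this suits occasional inversions rather than tight loops.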
Just FYI, it seems Apple released a new version of jax-metal a few weeks back, version 0.0.5, and this basic feature is still broken.
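To check whether a given jax-metal release handles this op, one option is to exercise a triangular solve on its own rather than going through inv. A minimal sketch; triangular_solve_works is a hypothetical helper, and it catches Exception broadly because we do not assume which exception type jax-metal raises on an unsupported op.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def triangular_solve_works(device):
    """Hypothetical smoke test: can `device` run a small triangular solve?"""
    l = jax.device_put(jnp.array([[2.0, 0.0], [1.0, 1.0]]), device)
    b = jax.device_put(jnp.array([2.0, 3.0]), device)
    try:
        # Solves L @ x = b for lower-triangular L; the exact answer is x = [1, 2].
        x = solve_triangular(l, b, lower=True).block_until_ready()
    except Exception:  # exact exception type on an unsupported backend is not assumed
        return False
    return bool(jnp.allclose(x, jnp.array([1.0, 2.0])))

for dev in jax.devices():
    print(dev, triangular_solve_works(dev))
```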
Same issue here. Any idea when a fix will be released?
Pinging this; I hope it gets fixed in the next release.
I think the JAX + Apple Silicon combo offers some unique advantages for prototyping models locally (i.e., less compute but more RAM than a comparable NVIDIA workstation), and JAX users will likely pick up on this quickly once features like this are fixed.
We are looking into adding the conversion for the 'mhlo.triangular_solve' op.
I also hope this gets fixed soon! Do you have any updates, @shuhand0? Thanks!