Metal plugin - failed to legalize operation 'mhlo.convolution'
Description
Encountered this on Apple M2 Pro after following instructions from https://developer.apple.com/metal/jax/, and then trying to get https://github.com/sanchit-gandhi/whisper-jax to run.
Steps for reproducing:
- Compile and install metal jax following instructions from: https://developer.apple.com/metal/jax/
- Install whisper-jax with:
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
from whisper_jax import FlaxWhisperPipline
# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
# JIT compile the forward call - slow, but we only do once
text = pipeline("audio.mp3")
Leads to the following error:
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError:
UNKNOWN: /Users/pere/jax-metal/lib/python3.10/site-packages/whisper_jax/layers.py:1236:0:
error: failed to legalize operation 'mhlo.convolution'
/Users/pere/jax-metal/lib/python3.10/site-packages/whisper_jax/layers.py:1236:0: note: see current operation:
%111 = "mhlo.convolution"(%110, <<UNKNOWN SSA VALUE>>) {
  batch_group_count = 1 : i64,
  dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
  feature_group_count = 1 : i64,
  lhs_dilation = dense<1> : tensor<1xi64>,
  padding = dense<1> : tensor<1x2xi64>,
  precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
  rhs_dilation = dense<1> : tensor<1xi64>,
  window_reversal = dense<false> : tensor<1xi1>,
  window_strides = dense<1> : tensor<1xi64>
} : (tensor<1x3000x80xf32>, tensor<3x80x384xf32>) -> tensor<1x3000x384xf32>
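For anyone trying to isolate this without installing whisper-jax, here is a minimal sketch of what appears to be the same op, reconstructed from the shapes and attributes in the trace above. The function name conv1d_nwc is made up for illustration, and the mapping of [b, 0, f]x[0, i, o]->[b, 0, f] to ("NWC", "WIO", "NWC") is my reading of the trace. I have not confirmed that this exact snippet fails on jax-metal; it runs fine on CPU.

```python
import jax
import jax.numpy as jnp
from jax import lax


def conv1d_nwc(x, w):
    # 1D convolution with channels-last input and a (width, in, out) kernel,
    # stride 1 and padding 1 on each side, as in the reported trace.
    return lax.conv_general_dilated(
        x,
        w,
        window_strides=(1,),
        padding=[(1, 1)],
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


# Shapes taken from the error message: tensor<1x3000x80xf32>, tensor<3x80x384xf32>.
x = jnp.zeros((1, 3000, 80), dtype=jnp.float32)
w = jnp.zeros((3, 80, 384), dtype=jnp.float32)

y = jax.jit(conv1d_nwc)(x, w)
print(y.shape)
```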
Maybe this operation isn't implemented yet? @skye @kulinseth @jyingl3.
What jax/jaxlib version are you using?
jaxlib-v0.4.10
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response
@kulinseth
Thanks @peregilk for filing the issue, we will take a look. cc @sding23
This is a dimension_numbers layout that we don't support yet. We will look into expanding our conversion patterns for the convolution op.
By the way, this also happens when you try to train a very simple neural net with flax, e.g., the one in the tutorial:
https://github.com/google/flax/tree/main/examples/mnist/
@shuhand0 Thanks. Any rough estimates on a time frame here?
Do you know if there is any workaround?
Thanks!
Maybe I am wrong here, but to me this looks like missing support for standard 1d conv layers. Is this correct? If so, this is blocking a lot of different use cases. For me it blocks a very interesting use of Whisper on Mac.
I have no idea of the complexity of implementing this, but is there any chance it could be prioritised?
@kulinseth @sding23 @skye
Has this been abandoned? @kulinseth Can I help? Quite the deal breaker that we cannot invert matrices, do 3d convolutions, eigenvector decomposition, etc.
@kulinseth - Could you give an update on this? I am still unable to use jax-metal to run whisper-jax, even using v0.0.5. Which of the needed operations are expected in the near future?
The issue is the lack of conv1d support. We will look into adding a conversion sequence that reshapes it to conv2d.
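The same idea can be sketched at the user level: add a dummy height-1 spatial axis, run a 2D convolution, then drop the axis again. This is only an illustration of the reshape trick, not the plugin's actual conversion. The helper names conv1d_direct and conv1d_via_conv2d are hypothetical, and whether the NHWC/HWIO 2D form lowers successfully on a given jax-metal version is an assumption I have not checked. On CPU the two paths agree numerically.

```python
import jax
import jax.numpy as jnp
from jax import lax


def conv1d_direct(x, w, pad=1):
    # Reference 1D convolution: x is (N, W, C), w is (K, I, O).
    return lax.conv_general_dilated(
        x, w,
        window_strides=(1,),
        padding=[(pad, pad)],
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


def conv1d_via_conv2d(x, w, pad=1):
    # Insert a height-1 axis so the op becomes a 2D convolution.
    x2 = x[:, None, :, :]  # (N, 1, W, C), read as NHWC
    w2 = w[None, :, :, :]  # (1, K, I, O), read as HWIO
    y2 = lax.conv_general_dilated(
        x2, w2,
        window_strides=(1, 1),
        padding=[(0, 0), (pad, pad)],  # no padding on the dummy axis
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y2[:, 0, :, :]  # drop the dummy axis: back to (N, W, O)


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (2, 16, 4), dtype=jnp.float32)
w = jax.random.normal(kw, (3, 4, 5), dtype=jnp.float32)

direct = conv1d_direct(x, w)
via2d = conv1d_via_conv2d(x, w)
print(direct.shape, via2d.shape)
```

If the 2D form does lower on your setup, wrapping the offending 1D layers this way might serve as a stopgap until the plugin handles conv1d itself.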
The issue should be fixed in jax-metal 0.0.7. Please reopen it if it is not.