Metal plugin - failed to legalize operation 'mhlo.convolution'

Open peregilk opened this issue 1 year ago • 10 comments

Description

I encountered this on an Apple M2 Pro after following the instructions at https://developer.apple.com/metal/jax/ and then trying to run https://github.com/sanchit-gandhi/whisper-jax.

Steps for reproducing:

  • Compile and install metal jax following instructions from: https://developer.apple.com/metal/jax/
  • Install whisper-jax with: pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
  • Run the following script:

from whisper_jax import FlaxWhisperPipline

# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")

# JIT compile the forward call - slow, but we only do it once
text = pipeline("audio.mp3")

Leads to the following error:

jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: 
UNKNOWN: /Users/pere/jax-metal/lib/python3.10/site-packages/whisper_jax/layers.py:1236:0: 
error: failed to legalize operation 'mhlo.convolution'

/Users/pere/jax-metal/lib/python3.10/site-packages/whisper_jax/layers.py:1236:0: note: see current operation: 
%111 = "mhlo.convolution"(%110, <<UNKNOWN SSA VALUE>>) 
{batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<1> : 
tensor<1xi64>, padding = dense<1> :
 tensor<1x2xi64>, precision_config = [#mhlo<precision DEFAULT>, 
#mhlo<precision DEFAULT>], 
rhs_dilation = dense<1> : 
tensor<1xi64>, window_reversal = dense<false> : 
tensor<1xi1>, 
window_strides = dense<1> : tensor<1xi64>} : (tensor<1x3000x80xf32>, tensor<3x80x384xf32>) -> tensor<1x3000x384xf32>

Maybe this operation isn't implemented yet? @skye @kulinseth @jyingl3.

What jax/jaxlib version are you using?

jaxlib-v0.4.10

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

peregilk avatar Jun 07 '23 20:06 peregilk

@kulinseth

hawkinsp avatar Jun 07 '23 21:06 hawkinsp

Thanks @peregilk for filing the issue, we will take a look. cc @sding23

kulinseth avatar Jun 08 '23 05:06 kulinseth

This convolution uses a dimension_numbers layout that we don't support yet. We will look into expanding our conversion patterns for the convolution op.
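
For readers who want to see the pattern in isolation, here is a minimal sketch of a 1D convolution with the layout shown in the log (`[b, 0, f]x[0, i, o]->[b, 0, f]`, i.e. an NWC input and a WIO kernel). The shapes, stride, and padding are copied from the error message. The function name `conv1d` is only illustrative, and the code is not taken from whisper-jax. `jax.eval_shape` is used so that the sketch only checks shapes and does not need a Metal device.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Shapes from the error log: input is (batch, width, features),
# kernel is (width, in_features, out_features).
x = jnp.zeros((1, 3000, 80), dtype=jnp.float32)
w = jnp.zeros((3, 80, 384), dtype=jnp.float32)


def conv1d(x, w):
    # One spatial dimension, stride 1, padding of 1 on each side,
    # matching window_strides and padding in the log.
    return lax.conv_general_dilated(
        x,
        w,
        window_strides=(1,),
        padding=((1, 1),),
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


# Abstract evaluation: computes the output shape without running the op.
out_shape = jax.eval_shape(conv1d, x, w).shape
print(out_shape)
```

On an affected jax-metal install, calling `jax.jit(conv1d)(x, w)` would presumably be the smallest way to trigger the same legalization error, though that is an assumption and was not verified here.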

shuhand0 avatar Jun 08 '23 15:06 shuhand0

By the way, this also happens when you try to train a very simple neural net with flax, e.g., the one in the tutorial: https://github.com/google/flax/tree/main/examples/mnist/

mlaves avatar Jun 12 '23 20:06 mlaves

@shuhand0 Thanks. Any rough estimates on a time frame here?

peregilk avatar Jun 20 '23 12:06 peregilk

Do you know if there is any workaround?

Thanks!

marcoacierno avatar Aug 14 '23 17:08 marcoacierno

Maybe I am wrong here, but to me this looks like missing support for standard 1D conv layers. Is this correct? If so, this is blocking a lot of different use cases. For me it blocks a very interesting use of Whisper on Mac.

I have no idea of the complexity of implementing this, but is there any chance it could be prioritised?

@kulinseth @sding23 @skye

peregilk avatar Aug 17 '23 16:08 peregilk

Has this been abandoned? @kulinseth Can I help? Quite the deal breaker that we cannot invert matrices, do 3d convolutions, eigenvector decomposition, etc.

syrkis avatar Nov 24 '23 22:11 syrkis

@kulinseth - Could you give an update on this? I am still unable to use jax-metal to run whisper-jax, even with v0.0.5. Which of the needed operations are expected in the near future?

peregilk avatar Feb 18 '24 07:02 peregilk

The issue is the lack of conv1d support. We will look into adding a conversion sequence that reshapes it to a conv2d.
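
The same idea can be sketched at the user level: a 1D convolution is equivalent to a 2D convolution over an input with a dummy spatial dimension of size 1. The snippet below is only an illustration of that equivalence, not the conversion implemented inside jax-metal. The helper names `conv1d_direct` and `conv1d_via_conv2d` are hypothetical, and whether this avoids the legalization error on affected versions has not been tested.

```python
import jax
import jax.numpy as jnp
from jax import lax


def conv1d_direct(x, w, pad):
    # Reference 1D convolution: x is (N, W, C), w is (W, I, O).
    return lax.conv_general_dilated(
        x, w,
        window_strides=(1,),
        padding=((pad, pad),),
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


def conv1d_via_conv2d(x, w, pad):
    # Add a height dimension of size 1 to both input and kernel.
    x2 = x[:, None, :, :]  # (N, 1, W, C)
    w2 = w[None, :, :, :]  # (1, W, I, O)
    y2 = lax.conv_general_dilated(
        x2, w2,
        window_strides=(1, 1),
        padding=((0, 0), (pad, pad)),  # no padding along the dummy axis
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    # Drop the dummy height dimension again.
    return y2[:, 0, :, :]


# Small random inputs (sizes chosen arbitrarily for the example).
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (1, 50, 8), dtype=jnp.float32)
w = jax.random.normal(kw, (3, 8, 16), dtype=jnp.float32)

y_direct = conv1d_direct(x, w, pad=1)
y_2d = conv1d_via_conv2d(x, w, pad=1)
matches = bool(jnp.allclose(y_direct, y_2d, rtol=1e-4, atol=1e-4))
print(y_2d.shape, matches)
```

The two results should agree up to floating-point tolerance, since the extra axis has extent 1 in both the input and the kernel.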

shuhand0 avatar Mar 12 '24 21:03 shuhand0

The issue should be fixed in jax-metal 0.0.7. Please reopen it if the problem persists.
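
To check whether an environment already has that release, something like the following sketch may help. It assumes the plugin is installed under the distribution name `jax-metal` and that its version string starts with dot-separated integers. The helper names `installed_version` and `has_conv1d_fix` are hypothetical.

```python
import re
from importlib import metadata

# Release that, according to this thread, contains the fix.
FIXED_IN = (0, 0, 7)


def installed_version(dist_name="jax-metal"):
    # Returns the installed version string, or None if the package is absent.
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def has_conv1d_fix(version):
    # Compare the first three numeric components against FIXED_IN.
    parts = tuple(int(p) for p in re.findall(r"\d+", version)[:3])
    return parts >= FIXED_IN


version = installed_version()
if version is None:
    print("jax-metal is not installed")
else:
    print(version, "has fix:", has_conv1d_fix(version))
```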

shuhand0 avatar May 28 '24 18:05 shuhand0