jax-metal, M1: failed to materialize conversion for result #0 of operation 'mhlo.convolution'

Open a-gn opened this issue 1 year ago • 2 comments

Description

import jax.lax as lax
import jax.numpy as jnp

lax.conv_general_dilated(
    jnp.ones((2, 4, 8, 8)),
    jnp.ones((4, 4, 3, 3)),
    window_strides=(1, 1),
    lhs_dilation=(2, 2),
    padding=((1, 1), (1, 1)),
)

On NVIDIA GPUs, the same operation runs to completion. On an M1 with jax-metal, it fails during compilation with this log:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-01-15 21:07:09.252402: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

Traceback (most recent call last):
  File "/Users/arno/projects/jmt/src/reproduce jax bug.py", line 4, in <module>
    lax.conv_general_dilated(
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/lax/convolution.py", line 156, in conv_general_dilated
    return conv_general_dilated_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 385, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 388, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 868, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 128, in apply_primitive
    compiled_fun = xla_primitive_callable(
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config.config._trace_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 161, in xla_primitive_callable
    compiled = computation.compile()
               ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2258, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2606, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2513, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/compiler.py", line 295, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/profiler.py", line 340, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/compiler.py", line 255, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/arno/projects/jmt/src/reproduce jax bug.py:4:0: error: failed to materialize conversion for result #0 of operation 'mhlo.convolution' that remained live after conversion
/Users/arno/projects/jmt/src/reproduce jax bug.py:4:0: note: see current operation: %1 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]>, feature_group_count = 1 : i64, lhs_dilation = dense<2> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense<false> : tensor<2xi1>, window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x8x8xf32>, tensor<4x4x3x3xf32>) -> tensor<2x4x15x15xf32>
<unknown>:0: note: see existing live user here: func.return %1 : tensor<2x4x15x15xf32>
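
The result type in the log above, tensor<2x4x15x15xf32>, can be checked by hand. The following is a minimal sketch of the usual output-size arithmetic for a convolution with input (lhs) dilation. The helper name `conv_output_size` is hypothetical and not part of JAX, and the snippet uses plain Python only, so it runs without touching the Metal backend.

```python
def conv_output_size(in_size, kernel, stride=1, lhs_dilation=1,
                     rhs_dilation=1, pad_lo=0, pad_hi=0):
    """Output length along one spatial axis (hypothetical helper, not a JAX API)."""
    # lhs_dilation inserts (lhs_dilation - 1) zeros between neighbouring inputs.
    dilated_in = (in_size - 1) * lhs_dilation + 1 if in_size > 0 else 0
    # rhs_dilation spreads the kernel taps in the same way.
    dilated_kernel = (kernel - 1) * rhs_dilation + 1
    # Padding is applied after the input has been dilated.
    padded = dilated_in + pad_lo + pad_hi
    return (padded - dilated_kernel) // stride + 1


# Values from the reproduction: 8x8 input, 3x3 kernel, stride 1,
# lhs_dilation 2, padding 1 on each side.
spatial = tuple(
    conv_output_size(8, 3, stride=1, lhs_dilation=2, pad_lo=1, pad_hi=1)
    for _ in range(2)
)
# Batch size 2 comes from the input; 4 output features come from the kernel.
out_shape = (2, 4) + spatial
print(out_shape)
```

Each spatial axis goes from 8 to (8 - 1) * 2 + 1 = 15 after dilation, to 17 after padding, and back to 15 after the 3-wide kernel. So the shape XLA asks for is the expected one, and the failure is in the Metal lowering rather than in the request.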

What jax/jaxlib version are you using?

jax 0.4.20, jax-metal 0.0.5, jaxlib 0.4.20

Which accelerator(s) are you using?

M1 GPU

Additional system info?

1.26.2
3.11.5 (main, Aug 24 2023, 15:09:45) [Clang 14.0.3 (clang-1403.0.22.14.1)]
uname_result(system='Darwin', node='mba-2.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:34 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T8103', machine='arm64')

NVIDIA GPU info

No response

a-gn avatar Jan 15 '24 20:01 a-gn

I'm also running into this problem. Here's my system information:

% python3 -c 'import jax; jax.print_environment_info()'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-15 11:09:00.577459: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 48.00 GB
maxCacheSize: 18.00 GB

jax:    0.4.25
jaxlib: 0.4.23
numpy:  1.26.4
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:54:21) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Connors-MacBook-Pro.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')

This is with jax-metal version 0.0.6.

ConnorBaker avatar Mar 15 '24 18:03 ConnorBaker

We are aware of the issue and will integrate the fix in the next release.

shuhand0 avatar Mar 15 '24 21:03 shuhand0

I noticed jax-metal 0.0.7 is out. Does it address this issue?

mikazlopes avatar May 04 '24 17:05 mikazlopes

> I noticed jax-metal 0.0.7 is out. Does it address this issue?

It does :)

Python 3.11.9 (main, Apr  2 2024, 08:25:04) [Clang 15.0.0 (clang-1500.3.9.4)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.lax as lax
>>> import jax.numpy as jnp
>>> 
>>> lax.conv_general_dilated(
...     jnp.ones((2, 4, 8, 8)),
...     jnp.ones((4, 4, 3, 3)),
...     window_strides=(1, 1),
...     lhs_dilation=(2, 2),
...     padding=((1, 1), (1, 1)),
... ).shape
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1714899046.272198 26673924 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1714899046.314145 26673924 service.cc:145] XLA service 0x60000309df00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1714899046.314172 26673924 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1714899046.315841 26673924 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1714899046.315851 26673924 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
(2, 4, 15, 15)

a-gn avatar May 05 '24 08:05 a-gn
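
For anyone landing here later: the thread reports the failure with jax-metal 0.0.5 and 0.0.6, and a successful run with 0.0.7. Below is a rough sketch of checking the installed version before running dilated convolutions. The 0.0.7 threshold is taken only from the comments above, the helper names are hypothetical, and the distribution name is assumed to be `jax-metal`.

```python
from importlib import metadata

# Assumption from this thread: 0.0.5 and 0.0.6 fail, 0.0.7 works.
FIXED_IN = (0, 0, 7)


def parse_version(text):
    """Turn '0.0.7' (or '0.0.7rc1') into a comparable tuple such as (0, 0, 7)."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at the first non-digit suffix, e.g. 'rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def has_convolution_fix(version_text):
    """True if the version is at least the one reported as working here."""
    return parse_version(version_text) >= FIXED_IN


def installed_jax_metal_version():
    """Return the installed jax-metal version string, or None if absent."""
    try:
        return metadata.version("jax-metal")
    except metadata.PackageNotFoundError:
        return None


version = installed_jax_metal_version()
if version is None:
    print("jax-metal is not installed")
elif has_convolution_fix(version):
    print(f"jax-metal {version}: at or above the version reported as fixed")
else:
    print(f"jax-metal {version}: older than 0.0.7, the lhs_dilation crash may occur")
```

This only compares version numbers. Rerunning the reproduction from the issue description is still the real test.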