jax-metal, M1: failed to materialize conversion for result #0 of operation 'mhlo.convolution'
Description
import jax.lax as lax
import jax.numpy as jnp

lax.conv_general_dilated(
    jnp.ones((2, 4, 8, 8)),
    jnp.ones((4, 4, 3, 3)),
    window_strides=(1, 1),
    lhs_dilation=(2, 2),
    padding=((1, 1), (1, 1)),
)
On NVIDIA GPUs, the same operation runs to completion. On an M1 with jax-metal, it fails at compile time with this log:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-01-15 21:07:09.252402: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
Traceback (most recent call last):
File "/Users/arno/projects/jmt/src/reproduce jax bug.py", line 4, in <module>
lax.conv_general_dilated(
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/lax/convolution.py", line 156, in conv_general_dilated
return conv_general_dilated_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 385, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 388, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 868, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 128, in apply_primitive
compiled_fun = xla_primitive_callable(
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config.config._trace_context(), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 161, in xla_primitive_callable
compiled = computation.compile()
^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2258, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2606, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2513, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/compiler.py", line 295, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/profiler.py", line 340, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/jax-metal/lib/python3.11/site-packages/jax/_src/compiler.py", line 255, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/arno/projects/jmt/src/reproduce jax bug.py:4:0: error: failed to materialize conversion for result #0 of operation 'mhlo.convolution' that remained live after conversion
/Users/arno/projects/jmt/src/reproduce jax bug.py:4:0: note: see current operation: %1 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]>, feature_group_count = 1 : i64, lhs_dilation = dense<2> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense<false> : tensor<2xi1>, window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x8x8xf32>, tensor<4x4x3x3xf32>) -> tensor<2x4x15x15xf32>
<unknown>:0: note: see existing live user here: func.return %1 : tensor<2x4x15x15xf32>
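For reference, the result type in the log, tensor<2x4x15x15xf32>, is the expected output shape for this call, so the failure is in lowering the op rather than in the requested shapes. A minimal sketch of the per-axis size arithmetic follows; the helper name conv_output_size is hypothetical and not part of JAX.

```python
def conv_output_size(in_size, kernel_size, stride=1, padding=(0, 0),
                     lhs_dilation=1, rhs_dilation=1):
    """Output size along one spatial axis of a dilated convolution.

    Hypothetical helper for illustration; not a JAX API.
    """
    # lhs_dilation inserts (lhs_dilation - 1) zeros between input elements.
    dilated_in = (in_size - 1) * lhs_dilation + 1
    # rhs_dilation spreads the kernel taps in the same way.
    dilated_kernel = (kernel_size - 1) * rhs_dilation + 1
    # Edge padding is applied to the already dilated input.
    padded = dilated_in + padding[0] + padding[1]
    return (padded - dilated_kernel) // stride + 1


# Values from the reproducer: 8x8 input, 3x3 kernel, lhs_dilation=2, padding=1.
print(conv_output_size(8, 3, stride=1, padding=(1, 1), lhs_dilation=2))  # 15
```

The 8 input positions dilate to 15, padding brings that to 17, and a 3-wide window at stride 1 yields 15 positions per axis.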
What jax/jaxlib version are you using?
jax 0.4.20, jax-metal 0.0.5, jaxlib 0.4.20
Which accelerator(s) are you using?
M1 GPU
Additional system info?
1.26.2 3.11.5 (main, Aug 24 2023, 15:09:45) [Clang 14.0.3 (clang-1403.0.22.14.1)] uname_result(system='Darwin', node='mba-2.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:34 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T8103', machine='arm64')
NVIDIA GPU info
No response
I'm also running into this problem. Here's my system information:
% python3 -c 'import jax; jax.print_environment_info()'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-15 11:09:00.577459: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max
systemMemory: 48.00 GB
maxCacheSize: 18.00 GB
jax: 0.4.25
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:54:21) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Connors-MacBook-Pro.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')
This is with jax-metal version 0.0.6.
We are aware of the issue and will integrate the fix in the next release.
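Until a fixed release is installed, one possible workaround is to do the input dilation by hand and then run an ordinary convolution without lhs_dilation. The sketch below assumes the default NCHW/OIHW layout from the reproducer and uses lax.pad with interior padding to insert the zeros. The helper name conv_manual_lhs_dilation is hypothetical, and this has not been verified on the Metal backend, where lax.pad itself may or may not lower correctly.

```python
import jax.numpy as jnp
from jax import lax


def conv_manual_lhs_dilation(lhs, rhs, lhs_dilation, padding):
    """Emulate lhs_dilation with lax.pad, then run a plain convolution.

    Hypothetical helper for illustration. Assumes NCHW input, OIHW kernel,
    and unit window strides, as in the reproducer.
    """
    zero = jnp.zeros((), lhs.dtype)
    # Each padding_config entry is (low, high, interior). An interior
    # padding of d - 1 inserts d - 1 zeros between neighbouring elements.
    spatial_cfg = [(lo, hi, d - 1) for (lo, hi), d in zip(padding, lhs_dilation)]
    # Batch and feature axes are left untouched.
    padded = lax.pad(lhs, zero, [(0, 0, 0), (0, 0, 0)] + spatial_cfg)
    return lax.conv_general_dilated(
        padded,
        rhs,
        window_strides=(1,) * len(spatial_cfg),
        padding="VALID",  # edge padding was already applied by lax.pad
    )


out = conv_manual_lhs_dilation(
    jnp.ones((2, 4, 8, 8)),
    jnp.ones((4, 4, 3, 3)),
    lhs_dilation=(2, 2),
    padding=((1, 1), (1, 1)),
)
print(out.shape)
```

On a backend where the original call works (for example CPU), the result can be compared against lax.conv_general_dilated with lhs_dilation=(2, 2) to confirm the two formulations agree.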
I noticed jax-metal 0.0.7 is out. Does it address this issue?
It does :)
Python 3.11.9 (main, Apr 2 2024, 08:25:04) [Clang 15.0.0 (clang-1500.3.9.4)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.lax as lax
>>> import jax.numpy as jnp
>>>
>>> lax.conv_general_dilated(
... jnp.ones((2, 4, 8, 8)),
... jnp.ones((4, 4, 3, 3)),
... window_strides=(1, 1),
... lhs_dilation=(2, 2),
... padding=((1, 1), (1, 1)),
... ).shape
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1714899046.272198 26673924 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1714899046.314145 26673924 service.cc:145] XLA service 0x60000309df00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1714899046.314172 26673924 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1714899046.315841 26673924 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1714899046.315851 26673924 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
(2, 4, 15, 15)
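To check whether an environment already has the fixed plugin, the installed jax-metal version can be compared against 0.0.7, the release reported above as resolving the crash. This is a small sketch using only the standard library; the helper names are hypothetical, and the parser assumes a plain numeric X.Y.Z version string.

```python
from importlib import metadata

# Per this thread, jax-metal 0.0.7 no longer fails on the reproducer.
FIXED_IN = (0, 0, 7)


def parse_version(text):
    """Parse 'X.Y.Z' into a tuple of ints (assumes purely numeric parts)."""
    return tuple(int(part) for part in text.split(".")[:3])


def has_conv_fix(version_text):
    """Return True if the given jax-metal version is 0.0.7 or newer."""
    return parse_version(version_text) >= FIXED_IN


try:
    installed = metadata.version("jax-metal")
    print(f"jax-metal {installed}: fix present = {has_conv_fix(installed)}")
except metadata.PackageNotFoundError:
    print("jax-metal is not installed in this environment")
```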