error: failed to legalize operation 'mhlo.erf' with jax-metal
Description
This MLP fails to run with jax-metal:
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.typing as jt

class MLP(nn.Module):
    mid_features: tuple[int, ...]
    out_features: int

    @nn.compact
    def __call__(
        self,
        x: jt.ArrayLike,
    ):
        for out_feature_count in self.mid_features:
            x = nn.Dense(out_feature_count)(x)
            x = nn.relu(x)
        return nn.Dense(self.out_features)(x)

mlp = MLP((64, 64, 64), 6)
prng_key = jax.random.key(7)
params = mlp.init(prng_key, jnp.ones((2, 32)))
data = jax.random.uniform(prng_key, (4, 32), float, -10000, 10000)
print(mlp.apply(params, data))
With jax-metal installed:
arno@mba-2 ~/p/reimpl (main) [1]> python test/test_mlp.py
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1725302902.052601 5204059 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1725302902.081618 5204059 service.cc:145] XLA service 0x11fd1e040 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725302902.081748 5204059 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1725302902.083499 5204059 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1725302902.083613 5204059 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
File "/Users/arno/projects/reimpl/test/test_mlp.py", line 24, in <module>
params = mlp.init(prng_key, jnp.ones((2, 32)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/projects/reimpl/test/test_mlp.py", line 17, in __call__
x = nn.Dense(out_feature_count)(x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/linear.py", line 256, in __call__
kernel = self.param(
^^^^^^^^^^^
File "/Users/arno/venv/rs311/lib/python3.11/site-packages/jax/_src/nn/initializers.py", line 335, in init
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/arno/venv/rs311/lib/python3.11/site-packages/jax/_src/random.py", line 831, in truncated_normal
return _truncated_normal(key, lower, upper, shape, dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:990:14: error: failed to legalize operation 'mhlo.erf'
value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1889:8: note: called from
v = self.scope.param(name, init_fn, *init_args, unbox=unbox, **init_kwargs)
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/linear.py:256:13: note: called from
kernel = self.param(
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1233:14: note: called from
y = run_fun(self, *args, **kwargs)
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:701:13: note: called from
return self._call_wrapped_method(fun, args, kwargs)
^
/Users/arno/projects/reimpl/test/test_mlp.py:17:16: note: called from
x = nn.Dense(out_feature_count)(x)
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1233:14: note: called from
y = run_fun(self, *args, **kwargs)
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:701:13: note: called from
return self._call_wrapped_method(fun, args, kwargs)
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:3103:13: note: called from
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:1101:10: note: called from
y = fn(root, *args, **kwargs)
^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:990:14: note: see current operation: %109 = "mhlo.erf"(%108) : (tensor<f32>) -> tensor<f32>
value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
^
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
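The traceback shows that the failure happens in mlp.init, not mlp.apply: Flax's nn.Dense uses lecun_normal as its default kernel initializer, which samples through jax.random.truncated_normal, and that is where mhlo.erf is emitted. One possible workaround on affected jax-metal versions, offered as an untested sketch, is to run the initializer with the CPU as the default device and then move the result to the Metal device. It assumes the CPU backend remains available alongside METAL when jax-metal is installed; init_kernel_on_cpu is a hypothetical helper name, not part of JAX or Flax.

```python
import jax
import jax.numpy as jnp

# Flax's nn.Dense uses lecun_normal as its default kernel initializer; it
# samples through jax.random.truncated_normal, which calls erf internally.
kernel_init = jax.nn.initializers.lecun_normal()


def init_kernel_on_cpu(key, shape, dtype=jnp.float32):
    """Hypothetical helper: run the initializer on CPU, then move the result."""
    # Assumption: the CPU backend is still available when jax-metal is installed.
    cpu = jax.devices("cpu")[0]
    with jax.default_device(cpu):
        kernel = kernel_init(key, shape, dtype)
    # jax.devices() lists the default backend's devices (METAL under jax-metal).
    return jax.device_put(kernel, jax.devices()[0])


kernel = init_kernel_on_cpu(jax.random.key(7), (32, 64))
print(kernel.shape, kernel.dtype)
```

The same with-block could wrap the report's params = mlp.init(prng_key, jnp.ones((2, 32))) call, followed by params = jax.device_put(params, jax.devices()[0]). Whether the subsequent mlp.apply then runs correctly on Metal has not been verified here.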
Without jax-metal:
arno@mba-2 ~/p/reimpl (main)> python test/test_mlp.py
[[ 636.52386 882.25415 2988.575 192.267 866.60596 -1293.6633 ]
[ 4585.539 2880.368 2868.1316 78.3667 1750.4458 -857.1804 ]
[ 988.6698 2208.9604 2891.1992 -431.1714 776.54626 -211.66962]
[ 979.84326 4824.3716 6499.3325 321.2257 1804.8367 336.66034]]
System info (python version, jaxlib version, accelerator, etc.)
Python 3.11.9 (v3.11.9:de54cf5be3, Apr 2 2024, 07:12:50) [Clang 13.0.0 (clang-1300.0.29.30)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1725302834.712529 5202857 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1725302834.728122 5202857 service.cc:145] XLA service 0x116f5d340 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725302834.728249 5202857 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1725302834.729747 5202857 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1725302834.729766 5202857 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
jax: 0.4.31
jaxlib: 0.4.31
numpy: 2.0.1
python: 3.11.9 (v3.11.9:de54cf5be3, Apr 2 2024, 07:12:50) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='mba-2.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:21 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T8103', machine='arm64')
Here is a minimal example which, I think, isolates the problem to the error function (erf) in jax-metal:
#! /usr/bin/env python
from jax.scipy.special import erf
erf(0)
Results in:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: ./test.py:5:0: error: failed to legalize operation 'mhlo.erf' ./test.py:5:0: note: see current operation: %0 = "mhlo.erf"(%arg0) : (tensor<f32>) -> tensor<f32>
Using: Python 3.12.6, jax 0.4.31, jaxlib 0.4.31, jax_metal 0.1.0
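To check which related special functions a particular jax-metal install can compile, a small probe along these lines may help. It is a sketch: the set of functions probed (erf plus erfc and erfinv) is an arbitrary illustrative choice, probe_ops is a hypothetical helper, and catching a bare Exception is deliberate because the exact exception class raised for unlegalizable ops can differ between jaxlib versions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf, erfc, erfinv

# Hypothetical probe list: erf from the report plus two related functions.
CANDIDATES = {"erf": erf, "erfc": erfc, "erfinv": erfinv}


def probe_ops(device=None):
    """Map each name to None on success, or to the first line of its error."""
    if device is None:
        device = jax.devices()[0]
    # Committing the input to `device` makes jit compile for that backend.
    x = jax.device_put(jnp.asarray(0.5, dtype=jnp.float32), device)
    results = {}
    for name, fn in CANDIDATES.items():
        try:
            # block_until_ready forces compilation and execution to finish.
            jax.jit(fn)(x).block_until_ready()
            results[name] = None
        except Exception as exc:  # e.g. XlaRuntimeError: failed to legalize ...
            results[name] = str(exc).splitlines()[0]
    return results


print(probe_ops())
```

On a backend that supports all three functions, every entry is None; on jax-metal 0.1.0 the report above suggests the erf entry would carry the "failed to legalize operation 'mhlo.erf'" message.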
As I mentioned in a discussion on this repo, I'd like to contribute implementations of these operations, but I don't know how to contribute to jax-metal, since its source doesn't seem to be public. Is that right?
The source is not open. 'erf' has been fixed in the 0.1.1 patch release.
Hi @a-gn,
This issue appears to be resolved in jax-metal 0.1.1 with JAX 0.5.0 and flax 0.10.2. The provided repro produces the expected result with these versions. Please find the details in the following screenshot.
Thank you.