error: failed to legalize operation 'mhlo.erf' with jax-metal

Open a-gn opened this issue 1 year ago • 3 comments

Description

This MLP fails to run with jax-metal:

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.typing as jt


class MLP(nn.Module):
    mid_features: tuple[int, ...]
    out_features: int

    @nn.compact
    def __call__(
        self,
        x: jt.ArrayLike,
    ):
        for out_feature_count in self.mid_features:
            x = nn.Dense(out_feature_count)(x)
            x = nn.relu(x)
        return nn.Dense(self.out_features)(x)


mlp = MLP((64, 64, 64), 6)
prng_key = jax.random.key(7)
params = mlp.init(prng_key, jnp.ones((2, 32)))
data = jax.random.uniform(prng_key, (4, 32), float, -10000, 10000)
print(mlp.apply(params, data))

with jax-metal installed:

arno@mba-2 ~/p/reimpl (main) [1]> python test/test_mlp.py 
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1725302902.052601 5204059 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1725302902.081618 5204059 service.cc:145] XLA service 0x11fd1e040 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725302902.081748 5204059 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1725302902.083499 5204059 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1725302902.083613 5204059 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
  File "/Users/arno/projects/reimpl/test/test_mlp.py", line 24, in <module>
    params = mlp.init(prng_key, jnp.ones((2, 32)))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/projects/reimpl/test/test_mlp.py", line 17, in __call__
    x = nn.Dense(out_feature_count)(x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/linear.py", line 256, in __call__
    kernel = self.param(
             ^^^^^^^^^^^
  File "/Users/arno/venv/rs311/lib/python3.11/site-packages/jax/_src/nn/initializers.py", line 335, in init
    return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/rs311/lib/python3.11/site-packages/jax/_src/random.py", line 831, in truncated_normal
    return _truncated_normal(key, lower, upper, shape, dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:990:14: error: failed to legalize operation 'mhlo.erf'
      value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
             ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1889:8: note: called from
    v = self.scope.param(name, init_fn, *init_args, unbox=unbox, **init_kwargs)
       ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/linear.py:256:13: note: called from
    kernel = self.param(
            ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1233:14: note: called from
          y = run_fun(self, *args, **kwargs)
             ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:701:13: note: called from
      return self._call_wrapped_method(fun, args, kwargs)
            ^
/Users/arno/projects/reimpl/test/test_mlp.py:17:16: note: called from
            x = nn.Dense(out_feature_count)(x)
               ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1233:14: note: called from
          y = run_fun(self, *args, **kwargs)
             ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:701:13: note: called from
      return self._call_wrapped_method(fun, args, kwargs)
            ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:3103:13: note: called from
      return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
            ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:1101:10: note: called from
      y = fn(root, *args, **kwargs)
         ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:990:14: note: see current operation: %109 = "mhlo.erf"(%108) : (tensor<f32>) -> tensor<f32>
      value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
             ^

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

without jax-metal:

arno@mba-2 ~/p/reimpl (main)> python test/test_mlp.py
[[  636.52386   882.25415  2988.575     192.267     866.60596 -1293.6633 ]
 [ 4585.539    2880.368    2868.1316     78.3667   1750.4458   -857.1804 ]
 [  988.6698   2208.9604   2891.1992   -431.1714    776.54626  -211.66962]
 [  979.84326  4824.3716   6499.3325    321.2257   1804.8367    336.66034]]
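
The traceback above ends in `jax.random.truncated_normal`, which the default `nn.Dense` kernel initializer calls and which lowers to `mhlo.erf`. One possible workaround on an affected jax-metal version is to run that sampling on the CPU backend. The sketch below assumes a CPU device is still exposed through `jax.devices("cpu")` when jax-metal is installed, which has not been verified here; the helper name `sample_truncated_normal_on_cpu` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumption: the CPU backend is available alongside the Metal plugin.
cpu = jax.devices("cpu")[0]


def sample_truncated_normal_on_cpu(seed, shape):
    """Draw truncated-normal samples with the computation placed on the CPU."""
    # Uncommitted computations inside this context default to the CPU device,
    # so the erf op is not handed to the Metal backend.
    with jax.default_device(cpu):
        key = jax.random.key(seed)
        return jax.random.truncated_normal(key, -2.0, 2.0, shape, jnp.float32)


sample = sample_truncated_normal_on_cpu(7, (32, 64))
print(sample.shape)
```

If this works in your setup, wrapping `mlp.init(...)` in the same `jax.default_device(cpu)` context may let parameter initialization succeed, though that has not been tested against jax-metal here.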

System info (python version, jaxlib version, accelerator, etc.)

Python 3.11.9 (v3.11.9:de54cf5be3, Apr  2 2024, 07:12:50) [Clang 13.0.0 (clang-1300.0.29.30)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1725302834.712529 5202857 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1725302834.728122 5202857 service.cc:145] XLA service 0x116f5d340 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725302834.728249 5202857 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1725302834.729747 5202857 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1725302834.729766 5202857 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.0.1
python: 3.11.9 (v3.11.9:de54cf5be3, Apr  2 2024, 07:12:50) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='mba-2.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:21 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T8103', machine='arm64')

a-gn avatar Sep 02 '24 18:09 a-gn

Here is a minimal reproducer. I think it shows that the error function itself is what fails in jax-metal.

#! /usr/bin/env python

from jax.scipy.special import erf

erf(0)

Results in:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: ./test.py:5:0: error: failed to legalize operation 'mhlo.erf'
./test.py:5:0: note: see current operation: %0 = "mhlo.erf"(%arg0) : (tensor<f32>) -> tensor<f32>

Using: python 3.12.6, jax 0.4.31, jaxlib 0.4.31, jax_metal 0.1.0
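
If only `erf` itself is needed, a stopgap could be to approximate it with elementwise operations instead of calling the missing primitive. The sketch below uses the Abramowitz and Stegun formula 7.1.26 (absolute error around 1.5e-7 in exact arithmetic). It is not what `jax.scipy.special.erf` does internally, the name `erf_approx` is hypothetical, and it assumes `exp`, `abs` and `sign` are supported by the backend in use.

```python
import math

import jax.numpy as jnp


def erf_approx(x):
    """Approximate erf(x) with Abramowitz & Stegun 7.1.26 (no erf primitive)."""
    x = jnp.asarray(x, dtype=jnp.float32)
    sign = jnp.sign(x)  # erf is odd, so evaluate on |x| and restore the sign
    ax = jnp.abs(x)
    t = 1.0 / (1.0 + 0.3275911 * ax)
    # Horner evaluation of a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5
    poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741
               + t * (-1.453152027 + t * 1.061405429))))
    return sign * (1.0 - poly * jnp.exp(-ax * ax))


# Compare against the standard-library reference at a few points.
points = (-2.0, -0.5, 0.0, 0.5, 2.0)
max_err = max(abs(float(erf_approx(v)) - math.erf(v)) for v in points)
print(max_err)
```

This does not help with library code that calls `erf` internally, such as `jax.random.truncated_normal`; it only covers direct calls in your own code.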

TheSkyentist avatar Sep 30 '24 20:09 TheSkyentist

As I said in a discussion on this repo, I'd like to try to contribute these operations, but I don't know how to contribute to jax-metal. The source doesn't seem to be public?

a-gn avatar Oct 13 '24 09:10 a-gn

The source is not open. The 'erf' issue has been fixed in the 0.1.1 patch.

shuhand0 avatar Oct 14 '24 17:10 shuhand0

Hi @a-gn,

This issue appears to be resolved in jax-metal 0.1.1 with JAX 0.5.0 and flax 0.10.2. The provided repro produces the expected result with those versions. Please find the details in the following screenshot.

[Screenshots not preserved]
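
To check whether a given environment already has the release mentioned in this thread, something like the following sketch could be used. It relies only on the standard library; the distribution name `jax-metal` and the helper names `parse_version` and `has_erf_fix` are assumptions for illustration, and the 0.1.1 threshold is taken from the comments above.

```python
from importlib import metadata
from itertools import takewhile


def parse_version(text):
    """Turn a string like '0.1.1' into a tuple of ints, e.g. (0, 1, 1)."""
    parts = []
    for piece in text.split("."):
        # Keep only the leading digits of each component ("1rc2" -> 1).
        digits = "".join(takewhile(str.isdigit, piece))
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def has_erf_fix(dist_name="jax-metal", minimum=(0, 1, 1)):
    """Return True/False for installed versions, or None if not installed."""
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None
    return parse_version(installed) >= minimum


print(has_erf_fix())
```

A result of `None` means the distribution was not found under that name, which is also what you would see on a machine without jax-metal.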

Thank you.

rajasekharporeddy avatar Feb 10 '25 10:02 rajasekharporeddy