Unable to initialize backend 'METAL'
Description
I ran the "Get Started" code from Apple's "Accelerated JAX training on Mac" page, namely:
python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel ml-dtypes==0.2.0
python -m pip install jax-metal
python -c 'import jax; print(jax.numpy.arange(10))'
On running that last line I get the following error:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
return lax.iota(dtype, start)
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
return broadcasted_iota(dtype, (size,), 0)
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/benjamv/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
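The error message offers an escape hatch: set JAX_PLATFORMS=cpu so that JAX skips the Metal plugin entirely. Below is a minimal Python sketch. It assumes only that the variable is set before jax is imported; JAX picks up its configuration when it loads, so setting the variable first is the safest order. Everything then runs on the CPU. That confirms jax and jaxlib themselves work, but it provides no Metal acceleration.

```python
import os

# Skip the failing METAL backend; set this before jax is imported.
os.environ["JAX_PLATFORMS"] = "cpu"

try:
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None  # jax is not installed in this interpreter

if jax is not None:
    print(jax.devices())     # should list only CPU devices
    print(jnp.arange(10))
```

From a shell, the equivalent is to prefix the command: JAX_PLATFORMS=cpu python -c 'import jax; print(jax.numpy.arange(10))'.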
System info (python version, jaxlib version, accelerator, etc.)
Running import jax; jax.print_environment_info() returns the following error:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:874, in backends()
873 try:
--> 874 backend = _init_backend(platform)
875 _backends[platform] = backend
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:965, in _init_backend(platform)
964 logger.debug("Initializing backend '%s'", platform)
--> 965 backend = registration.factory()
966 # TODO(skye): consider raising more descriptive errors directly from backend
967 # factories instead of returning None.
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:657, in register_plugin.<locals>.factory()
656 if not xla_client.pjrt_plugin_initialized(plugin_name):
--> 657 xla_client.initialize_pjrt_plugin(plugin_name)
658 updated_options = {}
File ~/jax-metal/lib/python3.10/site-packages/jaxlib/xla_client.py:176, in initialize_pjrt_plugin(plugin_name)
169 """Initializes a PJRT plugin.
170
171 The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or
(...)
174 plugin_name: the name of the PJRT plugin.
175 """
--> 176 _xla.initialize_pjrt_plugin(plugin_name)
XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Cell In [2], line 1
----> 1 jax.print_environment_info()
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/environment_info.py:45, in print_environment_info(return_string)
43 python_version = sys.version.replace('\n', ' ')
44 with np.printoptions(threshold=4, edgeitems=2):
---> 45 devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
46 info = textwrap.dedent(
47 f"""\
48 jax: {version.__version__}
(...)
55 """
56 )
57 nvidia_smi = try_nvidia_smi()
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1077, in devices(backend)
1052 def devices(
1053 backend: str | xla_client.Client | None = None
1054 ) -> list[xla_client.Device]:
1055 """Returns a list of all devices for a given backend.
1056
1057 .. currentmodule:: jaxlib.xla_extension
(...)
1075 List of Device subclasses.
1076 """
-> 1077 return get_backend(backend).devices()
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:1011, in get_backend(platform)
1007 @lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
1008 def get_backend(
1009 platform: None | str | xla_client.Client = None
1010 ) -> xla_client.Client:
-> 1011 return _get_backend_uncached(platform)
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:990, in _get_backend_uncached(platform)
986 return platform
988 platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)
--> 990 bs = backends()
991 if platform is not None:
992 platform = canonicalize_platform(platform)
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/xla_bridge.py:890, in backends()
888 else:
889 err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
--> 890 raise RuntimeError(err_msg)
892 assert _default_backend is not None
893 if not config.jax_platforms.value:
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
Running the command a second time results in:
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
More info
I get the same issue on both my M1 Max Mac Studio and my 2020 M1 MacBook Air, both running Sonoma 14.5.
Getting the same error on M1 Air.
Duplicate of https://github.com/google/jax/issues/20148 ?
pip install jax==0.4.26 jaxlib==0.4.26 gives:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716453894.367380 849760 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1716453894.452128 849760 service.cc:145] XLA service 0x600002350e00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716453894.452162 849760 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716453894.454461 849760 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716453894.454479 849760 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
loc("-":0:0): error: current mps dialect version is 1.0.0, can't parse version 1.1.0
/AppleInternal/Library/BuildRoots/1dd9a6a2-74cf-11ee-8ed5-2a65a1af8551/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1097: failed assertion `Error importing MLIR bytecode.
'
Abort trap: 6
which is the error in https://github.com/google/jax/issues/20338.
When I ran pip install jax==0.4.26 jaxlib==0.4.26, I think it worked.
jax.print_environment_info() now gives:
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
and running print(jax.numpy.arange(10)) in an ipython session gives
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716454190.044810 4298556 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
I0000 00:00:1716454190.058568 4298556 service.cc:145] XLA service 0x600000588a00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716454190.058577 4298556 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716454190.059879 4298556 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716454190.059894 4298556 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
[0 1 2 3 4 5 6 7 8 9]
"When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success."
This works for me on M1 Max. Thanks for sharing!
I'm trying to install via Poetry, and I find that jax-metal installs a version of jax that cannot be overridden by specifying an additional jax dependency:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.26", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
When installing this, despite the clear attempt at overriding the jax dependency, jax-metal still has it set to 0.4.28:
• Installing jax (0.4.28)
• Installing jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl)
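One likely contributor is worth checking. Under Poetry's documented caret rules, ^0.4.26 means >=0.4.26,<0.5.0. The jax 0.4.28 that gets installed therefore already satisfies the constraint in the pyproject above, so nothing is actually overridden. An exact pin (jax = "0.4.26") would be needed. Whether that then resolves depends on jax-metal 0.0.7's own jax requirement, which is not shown in this thread. The sketch below re-implements the caret expansion for plain numeric versions to make the bounds concrete. It is an illustration, not Poetry's resolver, and the helper names are hypothetical.

```python
def parse(version: str) -> tuple[int, ...]:
    """Parse a plain numeric version such as '0.4.28' into (0, 4, 28)."""
    return tuple(int(part) for part in version.split("."))


def caret_bounds(spec: str) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Expand a caret spec like '^0.4.26' into (inclusive lower, exclusive upper).

    The upper bound bumps the left-most non-zero component, or the last
    given component if all components are zero. Assumes at most 3 components.
    """
    parts = list(parse(spec.lstrip("^")))
    lower = tuple(parts + [0] * (3 - len(parts)))
    idx = next((i for i, p in enumerate(parts) if p != 0), len(parts) - 1)
    upper = parts[:idx] + [parts[idx] + 1] + [0] * (2 - idx)
    return lower, tuple(upper)


def satisfies(version: str, spec: str) -> bool:
    lower, upper = caret_bounds(spec)
    return lower <= parse(version) < upper


if __name__ == "__main__":
    for candidate in ("0.4.26", "0.4.28", "0.5.0"):
        print(candidate, satisfies(candidate, "^0.4.26"))
```

Running it prints True for 0.4.26 and 0.4.28 and False for 0.5.0. This is consistent with Poetry accepting jax 0.4.28 for the ^0.4.26 constraint.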
"When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success."
In short, this isn't working for me.
For additional information, if I try:
$ python3
>>> import jax
jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.
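Here, import jax itself fails on the jax/jaxlib version check, so jax.print_environment_info() cannot be used to see what was installed. The sketch below is a minimal diagnostic that reads installed versions from package metadata without importing jax. It assumes the PyPI distribution names shown (jax, jaxlib, jax-metal, ml-dtypes, numpy). The helper name is hypothetical.

```python
from importlib import metadata

# Distributions relevant to jax-metal, by their PyPI names.
PACKAGES = ("jax", "jaxlib", "jax-metal", "ml-dtypes", "numpy")


def installed_versions(packages=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if absent.

    Only package metadata is read, so this works even when `import jax` fails.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    found = installed_versions()
    for name, version in found.items():
        print(f"{name:10s} {version or 'not installed'}")
    if found["jax"] and found["jaxlib"] and found["jax"] != found["jaxlib"]:
        print("note: jax and jaxlib versions differ")
```

Run it inside the Poetry virtualenv (for example via poetry run python). That shows which jax and jaxlib the resolver actually picked, as opposed to what pyproject.toml requested.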
However, if I don't try to override, that is, with jax and jaxlib at version 0.4.27:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Note that it only updates jaxlib and not the jax dependency:
• Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)
Then I do what I tried before:
$ python3
>>> import jax
>>> jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/environment_info.py", line 45, in print_environment_info
devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1077, in devices
return get_backend(backend).devices()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
return _get_backend_uncached(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
bs = backends()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
The same behavior occurs if I change jax and jaxlib to version 0.4.28, or remove them entirely and let jax-metal install the jax and jaxlib versions it requires.
Machine and environment info:
Chip: Apple M1 Pro
macOS: Sonoma 14.5
Python version: 3.10.13
"RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)". This error comes from jaxlib, which strictly checks the PJRT API version equality. jax-metal 0.0.7 adopts PJRT API from jaxlib-0.4.26. We have been communicated to JAX team and the solution is to set env var ENABLE_PJRT_COMPATIBILITY=1 if running jax-metal with jaxlib>0.4.26. The info can also be found in PYPI jax-metal page: https://pypi.org/project/jax-metal/.
@BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care for. Try
import platform
platform.machine() # should give you 'arm64'
I got mine (M1 Air, Sonoma) working by doing
python -m pip install jax==0.4.26 jaxlib==0.4.26
python -m pip install jax-metal
after I used #19886 to set up my environment (but make sure that your shell is also arm64 before doing this, e.g. arch -arm64 zsh).
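The platform.machine() check above can be folded into a small report that also shows the macOS version and which interpreter is running. This helps when several Pythons, some of them x86_64 builds under Rosetta, are installed side by side. The sketch below is minimal and uses only the standard library. The helper name is hypothetical. The arm64 warning reflects the advice in this thread rather than an official check.

```python
import platform
import sys


def interpreter_report() -> dict:
    """Collect architecture, macOS version and interpreter details."""
    mac_release = platform.mac_ver()[0]  # '' when not running on macOS
    return {
        "machine": platform.machine(),   # 'arm64' for a native Apple silicon build
        "macos": mac_release or None,
        "python": platform.python_version(),
        "executable": sys.executable,
    }


if __name__ == "__main__":
    report = interpreter_report()
    for key, value in report.items():
        print(f"{key:10s} {value}")
    if report["macos"] and report["machine"] != "arm64":
        print("warning: not a native arm64 Python; use an arm64 interpreter for jax-metal")
```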
"@BeeGass Are you sure your python3 is an arm64 binary? I've been bitten by this more times than I care for. Try import platform; platform.machine()  # should give you 'arm64'"
I cleared the cache, all virtual environments, and so forth, and did a fresh install of all the dependencies, ensuring that arm64 is the correct platform.
$ python
>>> import platform
>>> platform.machine()
'arm64'
For the sake of thoroughness:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
I also ran this right before installing all dependencies:
arch -arm64 zsh
I tried the original test again:
$ python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/__init__.py", line 37, in <module>
import jax.core as _core
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/core.py", line 18, in <module>
from jax._src.core import (
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 39, in <module>
from jax._src import dtypes
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dtypes.py", line 33, in <module>
from jax._src import config
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/config.py", line 27, in <module>
from jax._src import lib
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 75, in <module>
version = check_jaxlib_version(
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 64, in check_jaxlib_version
raise RuntimeError(msg)
RuntimeError: jaxlib is version 0.4.26, but this version of jax requires version >= 0.4.27.
For the sake of thoroughness:
I changed the jax and jaxlib dependency versions to 0.4.28. (I know the matching PJRT version is in 0.4.26, but since that isn't working, I hoped that 0.4.28 or 0.4.27 might have a compatible PJRT version as well.)
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.27", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
• Downgrading jaxlib (0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.27 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.27-cp310-cp310-macosx_11_0_arm64.whl)
Again, I did the following:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
arch -arm64 zsh
Then I performed the test above:
$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
return lax.iota(dtype, start)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
return broadcasted_iota(dtype, (size,), 0)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
I also tried:
[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10.0,<=3.10.13"
ml-dtypes = "0.2.0"
jax-metal = { version = "^0.0.7", markers = "platform_machine == 'arm64'" }
jax = { version = "^0.4.28", source = "jax-macos", markers = "platform_machine == 'arm64'" }
jaxlib = { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", markers = "platform_machine == 'arm64'" }
[[tool.poetry.source]]
name = "jax-macos"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
priority = "primary"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
• Updating jaxlib (0.4.26 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl -> 0.4.28 https://storage.googleapis.com/jax-releases/mac/jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl)
Again, I did the following:
$ python3
>>> import platform
>>> platform.machine()
'arm64'
arch -arm64 zsh
Then I performed the test above:
$ python3 -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Traceback (most recent call last):
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
backend = _init_backend(platform)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
backend = registration.factory()
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
xla_client.initialize_pjrt_plugin(plugin_name)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
_xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2968, in arange
return lax.iota(dtype, start)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1282, in iota
return broadcasted_iota(dtype, (size,), 0)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1292, in broadcasted_iota
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 387, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/beegass/Library/Caches/pypoetry/virtualenvs/hippo-pkg-kPsu4IkM-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
outs = fun(*args)
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
I have noticed that the people who have been able to get things working with this fix have been using the M3 chip. Could my M1 chip be the issue? Has anyone tried to replicate this on an M1?
Just to be clear: Chip: Apple M1 Pro, macOS: Sonoma 14.5.
"@BeeGass Are you sure your python3 is an arm64 binary? [...] I got mine (M1 Air, Sonoma) working by doing python -m pip install jax==0.4.26 jaxlib==0.4.26, then python -m pip install jax-metal, after I used #19886 to set up my environment (but make sure that your shell is also arm64 before doing this, e.g. arch -arm64 zsh)."
Thanks, this works for me. Mine is M3 Pro, Sonoma 14.5.
@BeeGass, have you tried setting env ENABLE_PJRT_COMPATIBILITY=1 to run jax-metal with jaxlib>0.4.26?
Yes, still the same behavior. I'm told that the jax version needs to be 0.4.27 or higher, @shuhand0.
"When I ran pip install jax==0.4.26 jaxlib==0.4.26 then I think I got success." "This works for me on M1 Max. Thanks for sharing!"
This works for me on M3. Thank you!
Follow this link and try the version that matches your macOS version; this works for me!
Having the same issue as @BeeGass; any chance you've found a fix?
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
It works fine if I use version 0.4.26 for jax and jaxlib, but unfortunately I need 0.4.27+ for the package I'm trying to use.
Adding export ENABLE_PJRT_COMPATIBILITY=1 to my ~/.zshrc works. So far so good! (=
NB: I came here because I was getting the error on my M2 Pro running macOS Sonoma 14.6.1 (latest). I also saw this error on Ventura before I updated my OS. I did get it to work with pip install jax==0.4.26 jaxlib==0.4.26 as described above, but Flax and Optax both require jax >= 0.4.27, so there might be runtime errors when using 0.4.26 (unconfirmed).
Question: Is this a good test to see if XLA is working on Metal?
import numpy as np
import jax
import jax.numpy as jnp
@jax.jit
def mult(X, Y):
    return jnp.multiply(X, Y)
mat_shape = (3000, 3000)
%timeit mult(jnp.ones(mat_shape), jnp.ones(mat_shape)).block_until_ready()
%timeit np.multiply(np.ones(mat_shape), np.ones(mat_shape))
3.7 ms ± 183 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.2 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
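A speed-up alone does not prove that Metal is in use: XLA on the CPU can also beat NumPy for this kind of operation. The timings above also include allocating the inputs, and np.ones defaults to float64 while jnp.ones produces float32. A more direct check is to ask JAX which backend it chose and where the result lives. Earlier in this thread, jax.devices() lists METAL(id=0) when the plugin loads. The sketch below is a minimal version of both checks. It assumes float32 inputs for both libraries, builds them once outside the timed region, and warms up the jitted function so compilation is not timed. The exact platform string reported for Metal is not asserted here.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def mult(x, y):
    return jnp.multiply(x, y)


shape = (3000, 3000)
# Same dtype for both libraries so the comparison is like-for-like.
x_np = np.ones(shape, dtype=np.float32)
y_np = np.ones(shape, dtype=np.float32)
x = jnp.asarray(x_np)
y = jnp.asarray(y_np)

# Which backend did JAX pick, and on which device does the result live?
print("default backend:", jax.default_backend())
out = mult(x, y).block_until_ready()  # first call also compiles the function
print("result platforms:", {d.platform for d in out.devices()})

n = 10
t_jax = timeit.timeit(lambda: mult(x, y).block_until_ready(), number=n) / n
t_np = timeit.timeit(lambda: np.multiply(x_np, y_np), number=n) / n
print(f"jax: {t_jax * 1e3:.2f} ms/call, numpy: {t_np * 1e3:.2f} ms/call")
```

If the reported platform is cpu rather than Metal, the timings measure XLA on the CPU, regardless of how they compare with NumPy.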