jax
Jax metal failed to install
Description
Following the instructions on the PyPI page, jax-metal installs, but importing jax then fails:
(base) jakub.dokulil@nbm-imp-134 jd_python_learning % conda create -n jax_metal python=3.10
Channels:
- defaults
Platform: osx-64
Collecting package metadata (repodata.json): done
Solving environment: done
## Package Plan ##
environment location: /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal
added / updated specs:
- python=3.10
The following NEW packages will be INSTALLED:
bzip2 pkgs/main/osx-64::bzip2-1.0.8-h1de35cc_0
ca-certificates pkgs/main/osx-64::ca-certificates-2023.12.12-hecd8cb5_0
libffi pkgs/main/osx-64::libffi-3.4.4-hecd8cb5_0
ncurses pkgs/main/osx-64::ncurses-6.4-hcec6c5f_0
openssl pkgs/main/osx-64::openssl-3.0.13-hca72f7f_0
pip pkgs/main/osx-64::pip-23.3.1-py310hecd8cb5_0
python pkgs/main/osx-64::python-3.10.13-h5ee71fb_0
readline pkgs/main/osx-64::readline-8.2-hca72f7f_0
setuptools pkgs/main/osx-64::setuptools-68.2.2-py310hecd8cb5_0
sqlite pkgs/main/osx-64::sqlite-3.41.2-h6c40b1e_0
tk pkgs/main/osx-64::tk-8.6.12-h5d9f67b_0
tzdata pkgs/main/noarch::tzdata-2023d-h04d1e81_0
wheel pkgs/main/osx-64::wheel-0.41.2-py310hecd8cb5_0
xz pkgs/main/osx-64::xz-5.4.5-h6c40b1e_0
zlib pkgs/main/osx-64::zlib-1.2.13-h4dc903c_0
Proceed ([y]/n)?
Downloading and Extracting Packages:
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
#
# To activate this environment, use
#
# $ conda activate jax_metal
#
# To deactivate an active environment, use
#
# $ conda deactivate
(base) jakub.dokulil@nbm-imp-134 jd_python_learning % conda activate jax_metal
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install -U pip
Requirement already satisfied: pip in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (23.3.1)
Collecting pip
Using cached pip-24.0-py3-none-any.whl.metadata (3.6 kB)
Using cached pip-24.0-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
Attempting uninstall: pip
Found existing installation: pip 23.3.1
Uninstalling pip-23.3.1:
Successfully uninstalled pip-23.3.1
Successfully installed pip-24.0
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install numpy
Collecting numpy
Using cached numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl (20.6 MB)
Installing collected packages: numpy
Successfully installed numpy-1.26.4
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -m pip install jax-metal
Collecting jax-metal
Using cached jax_metal-0.0.5-py3-none-macosx_10_14_x86_64.whl.metadata (1.4 kB)
Requirement already satisfied: wheel~=0.35 in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (from jax-metal) (0.41.2)
Collecting six>=1.15.0 (from jax-metal)
Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting jax==0.4.20 (from jax-metal)
Using cached jax-0.4.20-py3-none-any.whl.metadata (23 kB)
Collecting jaxlib==0.4.20 (from jax-metal)
Downloading jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl.metadata (2.1 kB)
Collecting ml-dtypes>=0.2.0 (from jax==0.4.20->jax-metal)
Using cached ml_dtypes-0.3.2-cp310-cp310-macosx_10_9_universal2.whl.metadata (20 kB)
Requirement already satisfied: numpy>=1.22 in /Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages (from jax==0.4.20->jax-metal) (1.26.4)
Collecting opt-einsum (from jax==0.4.20->jax-metal)
Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting scipy>=1.9 (from jax==0.4.20->jax-metal)
Using cached scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl.metadata (60 kB)
Using cached jax_metal-0.0.5-py3-none-macosx_10_14_x86_64.whl (54.6 MB)
Using cached jax-0.4.20-py3-none-any.whl (1.7 MB)
Downloading jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl (82.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 82.6/82.6 MB 3.5 MB/s eta 0:00:00
Using cached ml_dtypes-0.3.2-cp310-cp310-macosx_10_9_universal2.whl (389 kB)
Using cached scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl (38.9 MB)
Installing collected packages: six, scipy, opt-einsum, ml-dtypes, jaxlib, jax, jax-metal
Successfully installed jax-0.4.20 jax-metal-0.0.5 jaxlib-0.4.20 ml-dtypes-0.3.2 opt-einsum-3.3.0 scipy-1.12.0 six-1.16.0
(jax_metal) jakub.dokulil@nbm-imp-134 jd_python_learning % python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/__init__.py", line 39, in <module>
from jax import config as _config_module
File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/config.py", line 15, in <module>
from jax._src.config import config as _deprecated_config # noqa: F401
File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/_src/config.py", line 28, in <module>
from jax._src import lib
File "/Users/jakub.dokulil/opt/anaconda3/envs/jax_metal/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 83, in <module>
cpu_feature_guard.check_cpu_features()
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.
System info (python version, jaxlib version, accelerator, etc.)
MacBook Air M2, macOS Sonoma 14.3.1 (23D60), Python 3.10
Based on the packages, is this an AMD GPU? Could you try a venv with python=3.9?
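The package list gives the architecture away: the conda packages are osx-64 (Intel) builds, and pip picked wheels tagged macosx_10_14_x86_64. A minimal sketch of reading that from a wheel filename, assuming the standard PEP 427 layout (name-version[-build]-python-abi-platform.whl); the helper names wheel_platform_tag and wheel_arch are hypothetical, and compressed multi-platform tags are not handled.

```python
# Sketch: infer which CPU architecture pip selected by looking at the
# platform tag at the end of a wheel filename (PEP 427 naming).

def wheel_platform_tag(filename: str) -> str:
    """Return the platform tag of a wheel filename."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    parts = filename[: -len(".whl")].split("-")
    # name, version, [build], python tag, abi tag, platform tag
    if len(parts) not in (5, 6):
        raise ValueError(f"unexpected wheel filename: {filename}")
    return parts[-1]


def wheel_arch(filename: str) -> str:
    """Guess the CPU architecture from a (macOS) wheel's platform tag."""
    tag = wheel_platform_tag(filename)
    for arch in ("arm64", "x86_64", "universal2"):
        if tag.endswith(arch):
            return arch
    return tag  # e.g. "any" for pure-Python wheels


# Wheel names copied from the pip log in the report.
for name in (
    "jaxlib-0.4.20-cp310-cp310-macosx_10_14_x86_64.whl",
    "jax_metal-0.0.5-py3-none-macosx_10_14_x86_64.whl",
    "jax-0.4.20-py3-none-any.whl",
):
    print(name, "->", wheel_arch(name))
```

Seeing x86_64 here on an M2 means pip was running inside an Intel interpreter, regardless of the Python version.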
Reproduces on my M2 Mac with both Python 3.10.6 and 3.9.13.
I also tried jax==0.4.11, jaxlib==0.4.11 and jax-metal==0.0.4; same result.
I haven't been able to reproduce the issue. The configuration below shows an installation and verification result:
ProductName: macOS
ProductVersion: 14.4
The following NEW packages will be INSTALLED:
ca-certificates pkgs/main/osx-64::ca-certificates-2023.12.12-hecd8cb5_0
libcxx pkgs/main/osx-64::libcxx-14.0.6-h9765a3e_0
libffi pkgs/main/osx-64::libffi-3.4.4-hecd8cb5_0
ncurses pkgs/main/osx-64::ncurses-6.4-hcec6c5f_0
openssl pkgs/main/osx-64::openssl-3.0.13-hca72f7f_0
pip pkgs/main/osx-64::pip-23.3.1-py39hecd8cb5_0
python pkgs/main/osx-64::python-3.9.18-h5ee71fb_0
readline pkgs/main/osx-64::readline-8.2-hca72f7f_0
setuptools pkgs/main/osx-64::setuptools-68.2.2-py39hecd8cb5_0
sqlite pkgs/main/osx-64::sqlite-3.41.2-h6c40b1e_0
tk pkgs/main/osx-64::tk-8.6.12-h5d9f67b_0
tzdata pkgs/main/noarch::tzdata-2024a-h04d1e81_0
wheel pkgs/main/osx-64::wheel-0.41.2-py39hecd8cb5_0
xz pkgs/main/osx-64::xz-5.4.6-h6c40b1e_0
zlib pkgs/main/osx-64::zlib-1.2.13-h4dc903c_0
Package Version
------------------ -------
importlib_metadata 7.0.2
jax 0.4.20
jax-metal 0.0.5
jaxlib 0.4.20
ml-dtypes 0.3.2
numpy 1.26.4
opt-einsum 3.3.0
pip 24.0
scipy 1.12.0
setuptools 68.2.2
six 1.16.0
wheel 0.41.2
zipp 3.17.0
python -c 'import jax; print(jax.numpy.arange(10))'
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-08 17:33:36.946600: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: AMD Radeon Pro Vega 20
systemMemory: 32.00 GB
maxCacheSize: 1.99 GB
[0 1 2 3 4 5 6 7 8 9]
Right, I think I was able to figure it out: in my case it was due to Python being i386 arch and not arm64. After switching arch and installing a native Python, it worked.
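An Intel Python on an Apple Silicon Mac runs under Rosetta 2 and reports x86_64, so pip installs x86_64 jaxlib wheels built with AVX, which the error above says this environment does not support. A minimal sketch for checking this, assuming macOS 11 or later where the sysctl.proc_translated key is available; the function name running_under_rosetta is hypothetical.

```python
import platform
import subprocess
from typing import Optional


def running_under_rosetta() -> Optional[bool]:
    """True if this Python process is translated by Rosetta 2, False if it
    runs natively on Apple Silicon, None if the check is not available."""
    if platform.system() != "Darwin":
        return None  # Not macOS, so the question does not apply.
    try:
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # sysctl missing, or the key is unknown (assumed: Intel Mac or old macOS).
        return None
    return result.stdout.strip() == "1"


print("platform.machine():", platform.machine())
print("translated by Rosetta 2:", running_under_rosetta())
```

On an M1/M2 with an Intel Anaconda install one would expect x86_64 and True; after switching to a native arm64 Python, arm64 and False.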
I just tried to install following the instructions on the Apple website (https://developer.apple.com/metal/jax/) and it failed with the same error as everyone else here, on an M2. How did you switch to a native python3?
I just ran the following code:
import platform

# Check the machine architecture
machine = platform.machine()
if machine == 'arm64':
    print("Your Python version is ARM64")
elif machine == 'i386':
    print("Your Python version is i386 (32-bit)")
elif machine == 'x86_64':
    print("Your Python version is x86_64 (64-bit)")
else:
    print(f"Unknown machine architecture: {machine}")
and the printout is:
Your Python version is x86_64 (64-bit)
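The x86_64 result confirms this is an Intel build of Python. Note that platform.machine() names the architecture, not the bit width, so the script's "i386 (32-bit)" label can mislead: the macOS arch command prints i386 for 64-bit Intel processes too. A minimal sketch that reports the pointer width separately; the function name interpreter_arch is hypothetical.

```python
import platform
import struct
import sys


def interpreter_arch() -> dict:
    """Report the architecture string and the real pointer width of the
    running Python interpreter."""
    return {
        "machine": platform.machine(),             # e.g. 'arm64' or 'x86_64'
        "pointer_bits": struct.calcsize("P") * 8,  # size of a C pointer in bits
        "is_64bit": sys.maxsize > 2**32,           # cross-check of the same fact
    }


info = interpreter_arch()
print(info)
if platform.system() == "Darwin" and info["machine"] == "x86_64":
    # On an Apple Silicon Mac, an x86_64 process is running under Rosetta 2.
    print("Intel build of Python: jax-metal needs a native arm64 Python on M1/M2.")
```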
@phisanti you switch architectures in your CLI with the arch command, then install Python afresh (it will be a different Python) and follow the jax-metal install instructions from Apple.
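After reinstalling, a quick sanity check can confirm the new interpreter is native before running the verification command used earlier in this thread. A minimal sketch; the function name check_jax_metal_env is hypothetical, and it only checks the conditions discussed here (macOS, arm64 Python, jax present), not every jax-metal requirement.

```python
import importlib.util
import platform


def check_jax_metal_env() -> list:
    """Return problems that would stop jax-metal from working in this
    interpreter; an empty list means none of the checked issues were found."""
    problems = []
    if platform.system() != "Darwin":
        problems.append(f"not macOS: {platform.system()}")
    if platform.machine() != "arm64":
        problems.append(f"Python is {platform.machine()}, not native arm64")
    if importlib.util.find_spec("jax") is None:
        problems.append("jax is not installed")
    return problems


problems = check_jax_metal_env()
if problems:
    print("Fix these first:", *problems, sep="\n  - ")
else:
    import jax

    # Same check as the verification command in the thread, plus the device list.
    print(jax.numpy.arange(10))
    print(jax.devices())
```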
@curlup thanks for the tip. It worked for me!