grok-1
Error when installing requirements
I have installed Python 3.10 and created a venv. When I try to run "pip install -r requirements.txt", I get:
ERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; 1.6.3 Requires-Python >=3.7,<3.10; 1.7.0 Requires-Python >=3.7,<3.10; 1.7.1 Requires-Python >=3.7,<3.10
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"
I believe this should be fixed now. Can you try again?
No, it doesn't work
WARNING: jax 0.4.25 does not provide the extra 'cuda12-pip'
INFO: pip is looking at multiple versions of jax[cuda12-pip] to determine which version is compatible with other requirements. This could take a while.
ERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; 1.6.3 Requires-Python >=3.7,<3.10; 1.7.0 Requires-Python >=3.7,<3.10; 1.7.1 Requires-Python >=3.7,<3.10
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"
I think this is a problem with my system: I'm on Windows, and as I understand it there is no CUDA support in jax for Windows.
Hello @ibab,
I'm getting the same error while installing the requirements in WSL 2 (Kali). It looks like the fix doesn't work, or I'm making some kind of mistake while installing. The error message is below:
Used command: pip install -r requirements.txt
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.3, 0.4.4, 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"
Hi @kesevone Are you installing via WSL or trying to install on windows? I use wsl and getting same error.
I'm trying to install on windows, now I'll try on wsl
If you successfully install on WSL can you tell me too
pip install dm-haiku
In requirements.txt it's dm_haiku==0.0.12, with an underscore ...
I got the same error and tried this: pip install git+https://github.com/deepmind/dm-haiku
It worked in my case.
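On the underscore vs hyphen question: pip treats dm_haiku and dm-haiku as the same package, because PEP 503 normalizes distribution names before comparison. A minimal sketch of that normalization rule (normalize is a hypothetical helper, not part of pip's public API):

```python
import re

def normalize(name):
    # PEP 503: lowercase, and collapse runs of '-', '_', '.' into a single hyphen
    return re.sub(r"[-_.]+", "-", name).lower()

normalize("dm_haiku")   # 'dm-haiku'
normalize("dm-haiku")   # 'dm-haiku'
```

So the underscore in requirements.txt is cosmetic; the failures in this thread come from the jaxlib wheel, not the dm-haiku spelling.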
I get an error on startup
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
File "c:\Users\Maksim\Desktop\grok-1\run.py", line 72, in <module>
main()
File "c:\Users\Maksim\Desktop\grok-1\run.py", line 63, in main
inference_runner.initialize()
File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 282, in initialize
runner.initialize(
File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 181, in initialize
self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 586, in make_mesh
device_mesh = mesh_utils.create_hybrid_device_mesh(
File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 373, in create_hybrid_device_mesh
per_granule_meshes = [create_device_mesh(mesh_shape, granule)
File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 373, in <listcomp>
per_granule_meshes = [create_device_mesh(mesh_shape, granule)
File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 302, in create_device_mesh
raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)
That just means you don't have the expected number of devices. mesh_shape describes the required configuration: in this case, 8 devices to distribute the model over and run inference on. If you don't allocate exactly 8 GPUs, it will not work; in any case, running inference with this model requires a minimum of 8 large GPUs.
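The check that raises this error can be sketched as follows (mesh_fits is a hypothetical helper mirroring the validation in jax.experimental.mesh_utils.create_device_mesh, which requires the device count to equal the product of the mesh shape):

```python
import math

def mesh_fits(mesh_shape, n_devices):
    # create_device_mesh raises ValueError unless this holds
    return n_devices == math.prod(mesh_shape)

mesh_fits((1, 8), 1)   # False: one local GPU vs grok-1's default local_mesh_config=(1, 8)
mesh_fits((1, 8), 8)   # True: eight devices satisfy the default mesh
```

This is why the log line "Detected 1 devices in mesh" is immediately followed by the ValueError: 1 != 1 * 8.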
@AndreSlavescu I have tried with 1: I used mesh_shape (1, 1) instead of mesh_shape (1, 8).
I also have the same problem. Have you fixed it now?
I got the same error, but running only:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
worked for me, so it's probably because of the dm_haiku problem as described above
In theory mesh_shape (1, 1) should work in the sense of not breaking at that point, but you won't have enough memory to load the model onto a single GPU, so you will get a CUDA OOM error.
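A rough sketch of why a single GPU OOMs, under assumed numbers (Grok-1 has roughly 314B parameters, and the checkpoint discussed in this thread is about 300 GB, consistent with 8-bit weights; an A100 has 80 GB of HBM):

```python
import math

# Back-of-the-envelope memory check; all figures are assumptions, not measured.
n_params = 314e9
bytes_per_param = 1                                # int8-quantized weights
weights_gb = n_params * bytes_per_param / 1e9      # ~314 GB for weights alone

gpu_mem_gb = 80                                    # one A100 80GB
gpus_for_weights = math.ceil(weights_gb / gpu_mem_gb)
# 4 GPUs just to hold the weights; activations, KV cache, and runtime
# buffers are why the default config shards across 8 GPUs.
```

So even a single 80 GB card covers only about a quarter of the weights, which matches the advice above to allocate 8 large GPUs.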
same issue
yep, I gave up.. 300GB downloaded for nothing :D
Running these commands after the error fixes the installation issue:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt
If you want to try it and don't have access to an 8-GPU cluster, there are cloud compute options, such as AWS SageMaker, EC2 instances, Lambda Labs, CoreWeave, and a few more, where you might be able to get an 8xA100 80GB (640GB total) allocation.
Change requirements.txt to:
dm-haiku==0.0.12
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4
sentencepiece==0.2.0
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
File "/workspace/grok-1/run.py", line 72, in <module>
main()
File "/workspace/grok-1/run.py", line 63, in main
inference_runner.initialize()
File "/workspace/grok-1/runners.py", line 282, in initialize
runner.initialize(
File "/workspace/grok-1/runners.py", line 181, in initialize
self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/grok-1/runners.py", line 586, in make_mesh
device_mesh = mesh_utils.create_hybrid_device_mesh(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.pyenv_mirror/user/current/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
per_granule_meshes = [create_device_mesh(mesh_shape, granule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/.pyenv_mirror/user/current/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)
How do you solve this?
After changing requirements.txt to:
dm-haiku==0.0.12
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4
sentencepiece==0.2.0
and then running pip install -r requirements.txt, it ~worked~. But when I run python3 run.py, I get this new issue:
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
File "/Users/matheuscardoso/Projects/grok-1/run.py", line 72, in <module>
main()
File "/Users/matheuscardoso/Projects/grok-1/run.py", line 63, in main
inference_runner.initialize()
File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 282, in initialize
runner.initialize(
File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 181, in initialize
self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 586, in make_mesh
device_mesh = mesh_utils.create_hybrid_device_mesh(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
per_granule_meshes = [create_device_mesh(mesh_shape, granule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)
Did anyone solve it?
I was able to run Grok-1 yesterday. As people have commented, what did the trick for us at CloudWalk (a Brazilian fintech) was to use our K8s cluster with at least 8xA100 GPUs (80 GB family). Grok-1 uses almost all the memory of the GPUs, so using only 1 or 2 GPUs will not give you enough memory.
Another thing that solved our problems was running: pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then we just needed to run python run.py, and voilà.
You can also pull this container to run grok: ghcr.io/nvidia/jax:grok
from JAX Toolbox