
Error when installing requirements

Open kesevone opened this issue 5 months ago • 24 comments

I have installed Python 3.10 and a venv. Trying to "pip install -r requirements.txt":

ERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; 1.6.3 Requires-Python >=3.7,<3.10; 1.7.0 Requires-Python >=3.7,<3.10; 1.7.1 Requires-Python >=3.7,<3.10
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"

kesevone avatar Mar 17 '24 20:03 kesevone
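
For context on why pip rejects this pin: jaxlib==0.4.25+cuda12.cudnn89 carries a PEP 440 "local version" suffix after the +, and wheels with local versions are never published on PyPI. pip can only resolve them when pointed at Google's wheel index via -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. A minimal sketch of the distinction (the helper name is made up for illustration):

```python
# Sketch: detect whether a pinned requirement uses a PEP 440 local version
# (the part after "+"). Such wheels are not hosted on PyPI, which is why a
# plain `pip install -r requirements.txt` cannot find them without -f.
def needs_extra_index(requirement: str) -> bool:
    """Hypothetical helper: True if the pinned version has a local suffix."""
    name, _, version = requirement.partition("==")
    return "+" in version

print(needs_extra_index("jaxlib==0.4.25+cuda12.cudnn89"))  # True: needs -f
print(needs_extra_index("numpy==1.26.4"))                  # False: on PyPI
```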

I believe this should be fixed now. Can you try again?

ibab avatar Mar 17 '24 20:03 ibab

I believe this should be fixed now. Can you try again?

No, it doesn't work

WARNING: jax 0.4.25 does not provide the extra 'cuda12-pip'
INFO: pip is looking at multiple versions of jax[cuda12-pip] to determine which version is compatible with other requirements. This could take a while.
ERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; 1.6.3 Requires-Python >=3.7,<3.10; 1.7.0 Requires-Python >=3.7,<3.10; 1.7.1 Requires-Python >=3.7,<3.10
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"

kesevone avatar Mar 17 '24 20:03 kesevone

I think this is a problem with my system. I have Windows, and as I understand it there are no CUDA builds of jax for Windows

kesevone avatar Mar 17 '24 21:03 kesevone

Hello @ibab,

I'm getting the same error while installing requirements in WSL-2 Kali. Looks like the fix doesn't work, or I'm making some kind of mistake while installing the requirements. Error message below:

Used command: pip install -r requirements.txt

ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip" (from jax[cuda12-pip]) (from versions: 0.4.3, 0.4.4, 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
ERROR: No matching distribution found for jaxlib==0.4.25+cuda12.cudnn89; extra == "cuda12_pip"

alpernae avatar Mar 17 '24 21:03 alpernae

I think this is a problem in my system, I have Windows, there is no CUDA in jax for Windows as I understand it

Hi @kesevone Are you installing via WSL or trying to install on Windows? I use WSL and get the same error.

alpernae avatar Mar 17 '24 21:03 alpernae

I think this is a problem in my system, I have Windows, there is no CUDA in jax for Windows as I understand it

Hi @kesevone Are you installing via WSL or trying to install on windows? I use wsl and getting same error.

I'm trying to install on Windows; now I'll try on WSL

kesevone avatar Mar 17 '24 21:03 kesevone

I think this is a problem in my system, I have Windows, there is no CUDA in jax for Windows as I understand it

Hi @kesevone Are you installing via WSL or trying to install on windows? I use wsl and getting same error.

I'm trying to install on windows, now I'll try on wsl

If you successfully install on WSL, can you tell me too?

alpernae avatar Mar 17 '24 21:03 alpernae

pip install dm-haiku

In requirements.txt it's dm_haiku==0.0.12, with an underscore ...

yarodevuci avatar Mar 17 '24 21:03 yarodevuci
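
Side note on the underscore: that particular difference should be harmless, because pip normalizes project names per PEP 503, treating runs of -, _ and . as equivalent. A quick sketch of that rule:

```python
import re

# PEP 503 name normalization: runs of "-", "_" and "." collapse to a single
# "-" and the name is lowercased, so dm_haiku and dm-haiku refer to the same
# PyPI project.
def normalize_name(name: str) -> str:
    return re.sub(r"[-_.]+", "-", name).lower()

print(normalize_name("dm_haiku"))                                # dm-haiku
print(normalize_name("dm_haiku") == normalize_name("dm-haiku"))  # True
```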

I got the same error and tried this: pip install git+https://github.com/deepmind/dm-haiku

It worked in my case.

ahsan3219 avatar Mar 17 '24 22:03 ahsan3219

I get an error on startup:

INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
  File "c:\Users\Maksim\Desktop\grok-1\run.py", line 72, in <module>
    main()
  File "c:\Users\Maksim\Desktop\grok-1\run.py", line 63, in main
    inference_runner.initialize()
  File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 282, in initialize
    runner.initialize(
  File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Maksim\Desktop\grok-1\runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 373, in <listcomp>
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Maksim\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\experimental\mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)

hidenway avatar Mar 17 '24 22:03 hidenway

That just means you don't have the appropriate number of devices. The mesh_shape is the configuration for what is expected: in this case, 8 devices to distribute the model over and run inference on. If you don't allocate exactly 8 GPUs, it will not work; running inference with this model requires a minimum of 8 large GPUs anyway.

AndreSlavescu avatar Mar 17 '24 23:03 AndreSlavescu
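
The check behind this error is straightforward: JAX multiplies the mesh_shape axes together and compares the product to the number of visible devices. A minimal sketch of that logic (not the actual mesh_utils code):

```python
import math

# Sketch of the device-count check behind this ValueError: a mesh of shape
# (1, 8) needs exactly 1 * 8 = 8 devices, so a single-GPU machine fails.
def check_device_count(num_devices: int, mesh_shape: tuple) -> None:
    required = math.prod(mesh_shape)
    if num_devices != required:
        raise ValueError(
            f"Number of devices {num_devices} must equal the product "
            f"of mesh_shape {mesh_shape}"
        )

check_device_count(8, (1, 8))    # ok: 8 devices match (1, 8)
# check_device_count(1, (1, 8)) would raise, as in the traceback above
```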

@AndreSlavescu I have tried with 1: I used mesh_shape (1, 1) instead of mesh_shape (1, 8)

yarodevuci avatar Mar 18 '24 04:03 yarodevuci

I also have the same problem. Have you fixed it now?

lbg-686 avatar Mar 18 '24 07:03 lbg-686

I got the same error, but running only:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

worked for me, so it's probably because of the dm_haiku problem as described above

felifri avatar Mar 18 '24 13:03 felifri

@AndreSlavescu I have tried with 1. i did mesh_shape (1, 1) instead mesh_shape (1, 8)

In theory that should work in terms of not breaking at that point, but you won't have enough memory to load the model onto a single GPU, so you will get a CUDA OOM error.

AndreSlavescu avatar Mar 18 '24 16:03 AndreSlavescu

same issue

imfunniee avatar Mar 18 '24 16:03 imfunniee

@AndreSlavescu I have tried with 1. i did mesh_shape (1, 1) instead mesh_shape (1, 8)

In theory that should work in terms of not breaking at that point, but you won't have enough memory to load the model onto a single GPU, so you will get a CUDA OOM error.

yep, I gave up.. 300GB downloaded for nothing :D

yarodevuci avatar Mar 18 '24 17:03 yarodevuci

Running these commands after the error fixed the installation issue:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt

ywiyogo avatar Mar 19 '24 00:03 ywiyogo

@AndreSlavescu I have tried with 1. i did mesh_shape (1, 1) instead mesh_shape (1, 8)

In theory that should work in terms of not breaking at that point, but you won't have enough memory to load the model onto a single GPU, so you will get a CUDA OOM error.

yep, I gave up.. 300GB downloaded for nothing :D

If you want to try it and don't have access to an 8-GPU cluster, there are cloud compute options (AWS SageMaker EC2 instances, Lambda Labs, CoreWeave, and a few more) where you might be able to get an 8xA100 80GB (640GB total) allocation.

AndreSlavescu avatar Mar 19 '24 01:03 AndreSlavescu
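
To put rough numbers on why the 8-GPU minimum holds (assumption: the ~300 GB figure mentioned above is the int8 checkpoint for Grok-1's ~314B parameters):

```python
# Back-of-the-envelope memory check, assuming ~314B parameters at 1 byte
# each (int8), consistent with the ~300 GB download mentioned above.
params_billion = 314
bytes_per_param = 1
weights_gb = params_billion * bytes_per_param   # ~314 GB of weights alone
single_gpu_gb = 80                              # one A100 80GB
cluster_gb = 8 * single_gpu_gb                  # 8xA100 = 640 GB total

print(weights_gb > single_gpu_gb)   # True: a single GPU OOMs on the weights
print(weights_gb < cluster_gb)      # True: the 8-GPU mesh has headroom
```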

Change requirements.txt to:

dm-haiku==0.0.12
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4
sentencepiece==0.2.0

guobi777 avatar Mar 19 '24 05:03 guobi777

INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': 
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
  File "/workspace/grok-1/run.py", line 72, in <module>
    main()
  File "/workspace/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/workspace/grok-1/runners.py", line 282, in initialize
    runner.initialize(
  File "/workspace/grok-1/runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/grok-1/runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.pyenv_mirror/user/current/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.pyenv_mirror/user/current/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)


How do you solve this?

Jintao97 avatar Mar 19 '24 07:03 Jintao97

After changing requirements.txt to:

dm-haiku==0.0.12
jax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4
sentencepiece==0.2.0

and then running pip install -r requirements.txt, it "worked". But when I run python3 run.py I just get this new issue:

INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
  File "/Users/matheuscardoso/Projects/grok-1/run.py", line 72, in <module>
    main()
  File "/Users/matheuscardoso/Projects/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 282, in initialize
    runner.initialize(
  File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/matheuscardoso/Projects/grok-1/runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)

Did anyone solve it?

user-matth avatar Mar 19 '24 13:03 user-matth

I was able to run Grok-1 yesterday. As people have commented, what did the trick for us at CloudWalk (a Brazilian fintech) was using our K8s cluster with at least 8xA100 GPUs (80 GB family). Grok-1 uses almost all the memory of the GPUs, so using only 1 or 2 GPUs will not give you enough memory.

Another thing that solved our problems was running: pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then, we just needed to run python run.py, and voilà.

cw-lucasgabriel avatar Mar 19 '24 14:03 cw-lucasgabriel

You can also pull this container to run grok: ghcr.io/nvidia/jax:grok from JAX Toolbox

sbhavani avatar May 01 '24 18:05 sbhavani