Luffy-Yao


Run the code on an instance with 8 GPUs, and make sure jax[cuda12_pip] is installed in your environment.

To install it, try the following command:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
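To confirm both points before running the code, a short Python check along these lines may help. This is a minimal sketch of my own, not part of the original code. It assumes the public jax.devices("gpu") call, which raises a RuntimeError when no GPU backend is available (typically a sign that a CPU-only jax/jaxlib was installed). It also treats 8 as the expected GPU count. EXPECTED_GPUS, check_gpu_count and main are hypothetical names chosen for illustration.

```python
# Hypothetical pre-flight check: is JAX installed with CUDA support,
# and can it see the expected number of GPUs?
EXPECTED_GPUS = 8  # assumption: matches the 8-GPU instance mentioned above


def check_gpu_count(n_gpu, expected=EXPECTED_GPUS):
    """Return (ok, message) comparing the visible GPU count to the expected one."""
    if n_gpu == expected:
        return True, f"OK: found {n_gpu} GPU device(s)"
    return False, f"Expected {expected} GPU device(s), found {n_gpu}"


def main():
    """Print a diagnostic and return 0 on success, 1 otherwise."""
    try:
        import jax  # imported lazily so the helper above works without JAX
    except ImportError:
        print("jax is not installed in this environment")
        return 1
    try:
        # Raises RuntimeError if JAX has no GPU backend (e.g. CPU-only wheel).
        gpus = jax.devices("gpu")
    except RuntimeError as err:
        print(f"No GPU backend available: {err}")
        return 1
    ok, msg = check_gpu_count(len(gpus))
    print(msg)
    return 0 if ok else 1


if __name__ == "__main__":
    main()
```

On a correctly configured 8-GPU machine, running the script should print "OK: found 8 GPU device(s)". Any other message points either to the install step or to the instance's GPU visibility (for example, a restrictive CUDA_VISIBLE_DEVICES).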