Luffy-Yao
Run the code on an instance with 8 GPUs, and make sure jax[cuda12_pip] is installed in your environment.
Try the following command:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
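After installing, it may help to confirm that JAX actually sees all 8 GPUs before running the code. The snippet below is a minimal sketch, not part of the original setup. The helper names count_gpus and check_gpu_setup are hypothetical, and the expected count of 8 is taken from the comment above. It relies only on jax.devices(), jax.default_backend(), jax.__version__, and the platform attribute of JAX device objects (which reads "gpu" for CUDA devices).

```python
def count_gpus(devices):
    """Count devices whose platform is 'gpu' (JAX reports CUDA devices this way)."""
    return sum(1 for d in devices if d.platform == "gpu")


def check_gpu_setup(expected_gpus=8):
    """Hypothetical sanity check: print what JAX sees and fail if GPUs are missing."""
    # Imported here so count_gpus stays usable even where JAX is not installed.
    import jax

    devices = jax.devices()
    n = count_gpus(devices)
    print(f"JAX {jax.__version__}, default backend: {jax.default_backend()}")
    print(f"Found {n} GPU device(s): {devices}")
    if n < expected_gpus:
        # Usually means the CPU-only jax wheel is installed or CUDA is not visible.
        raise RuntimeError(
            f"Expected {expected_gpus} GPUs but JAX sees {n}; "
            "check that jax[cuda12_pip] is installed in this environment."
        )
    return n
```

Calling check_gpu_setup() in the same environment should print a "gpu" default backend and 8 devices. If it reports "cpu" instead, the CUDA-enabled wheel is most likely not the one being imported.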