
GPU/TPU not found

Open lonewolfnsp opened this issue 3 years ago • 4 comments

I've been fumbling around all day trying to get Trax to use the GPU, but failed..

Is there some kind of clear guideline on how to get Trax to see the GPU? I came across the Trax framework in a course I took on NLP, where one of the instructors was Łukasz, co-author of the famous transformer paper "Attention Is All You Need". Trax does look nice.. but there is no clear guide on how to get it working with the GPU, and this is very annoying because the notebooks from the course have pretty complex models (transformers, Reformer, among others)..

My environment:
OS: Linux Mint 19.3
Trax: 1.3.9
Python: 3.8.10
GPU: Titan X with driver 460

Basically I created a new conda env: conda create -n trax python=3.8, then activated that env: conda activate trax, then installed Trax: pip install trax

I read online that I need to install jaxlib==0.1.57 and CUDA 11.2, but I can't get it to work.. the kernel crashes..
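For reference, a hedged sketch of what a GPU setup for this stack might look like, following the jax README linked later in this thread. The exact version numbers here are examples only (not taken from this thread's resolution): the jaxlib wheel must match the locally installed CUDA version, e.g. `+cuda111` for CUDA 11.1.

```shell
# Sketch only -- versions are illustrative, pick the jaxlib build
# matching your CUDA toolkit (see https://github.com/google/jax#installation).
conda create -n trax python=3.8
conda activate trax
pip install trax==1.3.8

# pip install trax pulls in a CPU-only jaxlib; replace it with a CUDA build:
pip install --upgrade "jaxlib==0.1.67+cuda111" \
    -f https://storage.googleapis.com/jax-releases/jax_releases.html
```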

please, can someone help me?

...

Environment information

OS: <your answer here>

$ pip freeze | grep trax
# your output here

$ pip freeze | grep tensor
# your output here

$ pip freeze | grep jax
# your output here

$ python -V
# your output here

For bugs: reproduction and error logs

# Steps to reproduce:
...
# Error logs:
...

lonewolfnsp avatar May 31 '21 18:05 lonewolfnsp

I haven't had the time to test and report it extensively, but for me trax==1.3.9 has not worked as expected (a Kaggle notebook timed out without completing 1k batches) and until I have time to investigate and report my findings, I'm running trax==1.3.8. As for GPU detection, trax does it out of the box. Could you try downgrading trax to 1.3.8?

csengor avatar Jun 01 '21 07:06 csengor

Thanks for responding.. I've tried installing 1.3.8, but when I try to train it still reports no GPU/TPU detected.. These are the steps I took:

  1. Set up conda env: conda create -n trax python=3.7
  2. Activate env: conda activate trax
  3. pip install trax==1.3.8
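A quick way to tell whether the problem is Trax or the underlying JAX install is to ask JAX directly which devices it sees. This is a diagnostic sketch (not from the thread): on a CPU-only jaxlib it reports "cpu", and if jax is missing entirely it says so instead of crashing.

```python
def jax_backend_report():
    """Return (platform, device list) that JAX sees, or a note if jax is absent."""
    try:
        import jax
        devices = jax.devices()
        # .platform is "gpu" on a working CUDA install, "cpu" otherwise.
        return devices[0].platform, [str(d) for d in devices]
    except ImportError:
        return "jax not installed", []

platform, devices = jax_backend_report()
print("platform:", platform)
print("devices:", devices)
```

If this prints "cpu", no amount of Trax configuration will reach the GPU — the jaxlib wheel itself is the CPU build.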

Using conda list, I notice that cudnn is not present.. neither is tensorflow-gpu.

Please, can someone advise me on what I need to do to get Trax to use the GPU for training?

lonewolfnsp avatar Jun 01 '21 15:06 lonewolfnsp

You need to install a GPU version of jax; for details, see https://github.com/google/jax#installation

gdkrmr avatar Jun 25 '21 10:06 gdkrmr

Hi, thanks for responding to my queries. Yes, I have installed jaxlib+cuda111. In fact, I have run this notebook locally using my trax environment: https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb

It's actually a JAX notebook that shows how a classifier can be written in JAX. Training it locally, I can see that my GPU utilization stays above 10% throughout training, hitting 30+% most of the time, and the epoch time is about 1 second. This proves that my jaxlib+cuda111 is indeed using the GPU and that the CUDA toolkit is set up correctly.

Now, if you refer to the Trax examples on Fashion-MNIST and WideResNet, you'll notice they use the TensorFlow-NumPy backend, with this code in the first cell: trax.fastmath.set_backend('tensorflow-numpy'). I ran these 2 notebooks and could see that training uses the GPU, with GPU utilization at 100% most of the time. My timing is 46 sec/100 steps for Fashion-MNIST when using the tensorflow-numpy backend. But if I comment out that line, Trax defaults to the JAX backend, and it's a totally different story: GPU utilization drops to 0%, CPU utilization becomes very high, and training becomes painfully slow..
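The backend switch described above can be sketched as below. The `set_backend('tensorflow-numpy')` call is the one quoted from the example notebooks; the helper wrapping it is hypothetical, added only so the switch degrades gracefully when Trax isn't installed.

```python
def choose_backend(name='tensorflow-numpy'):
    """Try to activate the given Trax backend; return True on success.

    Passing 'jax' (the default backend) instead reproduces the slow,
    CPU-bound behavior described in this comment.
    """
    try:
        import trax
        trax.fastmath.set_backend(name)  # the call used in the example notebooks
        return True
    except ImportError:
        return False  # trax not installed in this environment

print("backend switched:", choose_backend())
```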

This is very puzzling: the JAX notebook shows that JAX/jaxlib can use the GPU for training with no issues, and the tensorflow-numpy backend can also use the GPU for training, so why is it that when Trax uses the JAX backend, it does not use the GPU? This could be something wrong with the way Trax is using JAX. I realize that if JAX's jit is not used, it's very slow. Could this be the reason Trax is not using the GPU?

If you refer to the NMT example, which uses a TPU in Colab, there's a chunk of code to set up the TPU backend for Trax. I've tried it in Colab with a TPU and it does work. In fact, when I modified the code to use the tensorflow-numpy backend, I could even train locally, and it does use the GPU, with steps taking nearly the same time as the TPU on Colab.

If you have any advice on how to get Trax to use the GPU with the JAX backend, please advise me..

One thing I don't understand: if JAX in Trax can use the GPU, why then is there a need for the tensorflow-numpy backend?


lonewolfnsp avatar Jun 25 '21 16:06 lonewolfnsp