Trax ML: GPU memory allocated but computation runs on CPU
Description
Hi, I have run the example ende translation script and have installed jax+cuda, so I do not get the typical "No GPU/TPU found, falling back to CPU" error. However, 'nvidia-smi' and 'top' show that most of my GPU memory is allocated by jax while the GPU itself sits nearly idle and my CPU runs at 100%. I have checked which device jax is using and it reports gpu:0. The translation does eventually finish, but it is very slow, and I strongly suspect the computation is happening on the CPU, as shown below. Why might this be?
Environment information
OS: Ubuntu 20.04
CUDA: 11.2
Jax: 0.2.17
Trax: 1.3.9

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.3.0
tensorflow-estimator==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.1.0
tensorflow-text==2.5.0

$ pip freeze | grep jax
jax==0.2.17
jaxlib==0.1.68+cuda101

$ python3 -V
Python 3.8.10
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device_buffer.device())
gpu:0
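Continuing the same session, a minimal sketch that is not in the original report (the array size is arbitrary): timing a large matmul with block_until_ready() while watching nvidia-smi shows whether the GPU is actually executing kernels rather than just holding memory.
>>> import time
>>> x = jnp.ones((4096, 4096))
>>> t0 = time.time()
>>> y = jnp.dot(x, x).block_until_ready()  # wait for the device before timing
>>> print("matmul took %.3f s" % (time.time() - t0))
If the GPU is doing the work, Volatile GPU-Util should spike while this runs.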
Nvidia-smi
NVIDIA GeForce GTX 1650 Ti, Memory-Usage: 3876/3914 MiB, Volatile GPU-Util: 9%
GPU Memory Allocation
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1046      G   /usr/lib/xorg/Xorg                 45MiB |
|    0   N/A  N/A      1621      G   /usr/lib/xorg/Xorg                124MiB |
|    0   N/A  N/A      1795      G   /usr/bin/gnome-shell              100MiB |
|    0   N/A  N/A      2981      C   python3                          3541MiB |
+-----------------------------------------------------------------------------+
CPU Usage
Command: python3 %CPU: 100.0 %MEM: 19.6
I'm having the exact same problem. When I run the TensorBoard profiler/trace viewer, it shows a tiny bit of startup activity on the GPU, then nothing - it's all CPU from there. I can run a raw jax loop in the same notebook and the GPU is used fine. I'm at my wits' end with this - Trax is far too slow to be useful on the CPU. How is anyone getting this to work?
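For reference, a raw jax loop of the kind mentioned above could look like the rough sketch below (my own version, not the commenter's; the matrix size and step count are arbitrary). If CUDA compute is working, running it should push Volatile GPU-Util well above the ~9% reported earlier.
import jax
from jax import numpy as jnp

@jax.jit
def step(x):
    # one matmul per iteration, compiled by XLA and run on the default device
    return jnp.tanh(jnp.dot(x, x))

x = jnp.ones((2048, 2048))
for _ in range(100):
    x = step(x)
x.block_until_ready()  # wait for the device so nvidia-smi reflects the work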
I experience this issue too, but in my case it happens during training. I have tried installing different jax and jaxlib versions, but it did not help. I have no idea what is wrong; can anyone help?
I'm facing the exact same problem! I've been investigating this issue for three days and couldn't find a solution!
When I use jax alone or TensorFlow alone and monitor the GPU, I can see that they use the GPU properly. But from Trax, only the 'tensorflow-numpy' backend actually uses the GPU (both memory and computation); when I set the backend to 'jax', only the memory is allocated, without any computation on the GPU.
Any help?
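A quick way to confirm that each framework sees the GPU on its own, as described above (a small sketch, assuming standard jax and TF 2.x APIs):
import jax
import tensorflow as tf

print(jax.devices())                           # expect a GPU device listed here
print(tf.config.list_physical_devices('GPU'))  # expect one PhysicalDevice entry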
Same problem for me; please post if you find a solution.
Setting trax.fastmath.set_backend('tensorflow-numpy') seems to help; I can see the GPU cycles being used.
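In case it helps, a minimal sketch of that workaround in context (the Serial/Dense model is an illustrative placeholder, not the actual ende script): the backend has to be set before any layers or the training loop are constructed.
import trax
from trax import layers as tl

# Switch from the default 'jax' backend to the TF-numpy backend before
# building anything, so every layer and the training loop run under it.
trax.fastmath.set_backend('tensorflow-numpy')

# Illustrative toy model; the real ende script builds a Transformer instead.
model = tl.Serial(tl.Dense(512), tl.Relu(), tl.Dense(512))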