
Trax ML: GPU memory allocated but computation runs on the CPU

Open ghost opened this issue 3 years ago • 5 comments

Description

Hi, I have run the example ende translation script and installed jax+cuda, so I do not get the usual "No GPU/TPU found, falling back to CPU" warning. However, 'nvidia-smi' and 'top' show that most of my GPU memory is allocated by jax while the GPU itself sits nearly idle, and my CPU is running at 100%. I have checked which device jax is using and it reports gpu:0. The script does eventually translate, but it is very slow, and I strongly suspect the computation is happening on my CPU, as shown below. Why might this be?

Environment information

OS: Ubuntu 20.04
CUDA: 11.2
jax: 0.2.17
Trax: 1.3.9

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.3.0
tensorflow-estimator==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.1.0
tensorflow-text==2.5.0

$ pip freeze | grep jax
jax==0.2.17
jaxlib==0.1.68+cuda101

$ python3 -V
Python 3.8.10

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device_buffer.device())
gpu:0
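
A slightly fuller device check with plain JAX calls looks roughly like the sketch below (not Trax-specific; jax.devices() and the xla_bridge platform query are standard JAX diagnostics in this version):

>>> import jax
>>> from jax.lib import xla_bridge
>>> jax.devices()                        # lists every device jaxlib can see
>>> xla_bridge.get_backend().platform    # 'gpu' means the CUDA build of jaxlib is active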

Nvidia-smi

NVIDIA GeForce GTX 1650 Ti | Memory-Usage: 3876MiB / 3914MiB | Volatile GPU-Util: 9%

GPU Memory Allocation

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                    Usage     |
|=============================================================================|
|    0   N/A  N/A      1046      G   /usr/lib/xorg/Xorg                 45MiB |
|    0   N/A  N/A      1621      G   /usr/lib/xorg/Xorg                124MiB |
|    0   N/A  N/A      1795      G   /usr/bin/gnome-shell              100MiB |
|    0   N/A  N/A      2981      C   python3                          3541MiB |
+-----------------------------------------------------------------------------+

CPU Usage

Command: python3 %CPU: 100.0 %MEM: 19.6

ghost avatar Jul 30 '21 11:07 ghost

I'm having exactly the same problem. When I run the TensorBoard profiler/trace viewer, it shows a tiny bit of startup activity on the GPU and then nothing; everything runs on the CPU from there. I can run a raw jax loop in the same notebook and the GPU is used fine. I'm at my wit's end with this: Trax is far, far too slow to be useful on the CPU. How is anyone getting this to work?
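
For concreteness, the kind of raw JAX loop I mean is roughly the sketch below (the matrix size and iteration count are arbitrary); while it runs, nvidia-smi shows real GPU utilization, which the Trax run never does:

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # A jitted matmul, heavy enough for GPU utilization to show up in nvidia-smi.
    return jnp.dot(x, x)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4000, 4000))
for _ in range(100):
    y = step(x)
y.block_until_ready()            # force execution before checking the device
print(y.device_buffer.device())  # same device check as in the original report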

OtwellResearch avatar Feb 15 '22 20:02 OtwellResearch

I'm experiencing this issue too, but during training. I have tried installing different jax and jaxlib versions, but it didn't help. I have no idea what's going on; can anyone help?

ccmehk avatar Apr 04 '22 09:04 ccmehk

I'm facing the exact same problem! I've been investigating this issue for 3 days and I couldn't find a solution!

When I use JAX alone or TensorFlow alone and monitor the GPU, I can see that each of them uses the GPU properly. But from Trax, only the 'tensorflow-numpy' backend actually uses the GPU (memory and computation); when I set the backend to 'jax', only the memory is allocated and no computation happens on the GPU.
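
Roughly how I compared the two paths while watching nvidia-smi (a sketch; the device queries are standard TensorFlow/JAX calls, and I'm assuming set_backend also accepts the 'jax' name):

>>> import tensorflow as tf, jax, trax
>>> tf.config.list_physical_devices('GPU')         # TensorFlow alone sees the GPU
>>> jax.devices()                                   # JAX alone also reports the GPU
>>> trax.fastmath.set_backend('jax')                # then run the model: only memory fills, GPU-Util stays ~0%
>>> trax.fastmath.set_backend('tensorflow-numpy')   # then run the model: memory and compute both on the GPU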

Any help?

ashraf-kasem avatar Nov 24 '22 13:11 ashraf-kasem

Same problem for me, please post if you found a solution.

sushilks avatar Dec 02 '22 06:12 sushilks

Calling trax.fastmath.set_backend('tensorflow-numpy') seems to help; I can see GPU cycles being used.
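
In case it helps, a minimal sketch of where I put that call; the training-loop part is just a placeholder for whatever script you already have:

import trax
from trax import fastmath

# Switch Trax's math backend before building any layers or the training loop;
# with the default 'jax' backend only GPU memory was allocated, no compute.
fastmath.set_backend('tensorflow-numpy')

# ... build the model and run the usual trax.supervised training loop here ...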

sushilks avatar Dec 02 '22 07:12 sushilks