
CUDA out of memory

[Open] Tronic opened this issue 1 year ago • 6 comments

Trying to load the medium or large model, I get out-of-memory errors. Loading the small model with float16 precision works, but it takes all of my 24 GB of VRAM. Is there any way to limit JAX's memory usage? The OpenAI model is far more modest in its requirements. Reducing the model weights to float16 seems like a good idea too.

Tronic avatar Apr 20 '23 22:04 Tronic

See related: https://github.com/huggingface/transformers/issues/22224

sanchit-gandhi avatar Apr 21 '23 11:04 sanchit-gandhi

You can also convert the parameters to float16/bfloat16 as follows:

```python
# for fp16
pipeline.params = pipeline.model.to_fp16(pipeline.params)
# for bf16
pipeline.params = pipeline.model.to_bf16(pipeline.params)
```

sanchit-gandhi avatar Apr 21 '23 17:04 sanchit-gandhi

@sanchit-gandhi It is a bit concerning that it can take 30+ GB of GPU memory during batch inference. What batch size would be ideal to keep usage low, say under 12 GB of VRAM?
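
For context, I mean something along these lines (just a sketch, assuming the pipeline's `batch_size` argument; the value 4 is an arbitrary example, not a recommendation):

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# a smaller batch_size should trade throughput for lower peak memory
pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.float16, batch_size=4)
```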

arnavmehta7 avatar Apr 21 '23 18:04 arnavmehta7

I tried running the medium model on a T4 Colab instance. It took 14 minutes to transcribe a 10-minute audio file. Is this due to the memory constraints and the model paging out? Or is it running on the CPU altogether?

seboslaw avatar Apr 24 '23 11:04 seboslaw

I get this error after updating the video card drivers or the kernel and forgetting to reboot afterwards. You can use GreenWithEnvy (gwe), available in most distro repos, to profile Nvidia cards and see what, if anything, is going on there. Update: gwe seems like a bloated version of nvidia-smi, which already comes with the video drivers, so just use that.

themanyone avatar Apr 27 '23 12:04 themanyone

Note that the phenomenon of JAX using 90% of your GPU memory just to load the model is due to JAX's GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

JAX doesn't actually require all of this memory, but blocks it out to prevent fragmentation.

If you want to disable this, you can do so with the environment variable XLA_PYTHON_CLIENT_PREALLOCATE:

XLA_PYTHON_CLIENT_PREALLOCATE=false python run_benchmark.py
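
The same page also documents XLA_PYTHON_CLIENT_MEM_FRACTION for capping the preallocated fraction. If you would rather set the flag from Python, a minimal sketch (the variable must be set before JAX is imported, otherwise the preallocation has already happened):

```python
import os

# disable JAX's up-front GPU memory preallocation;
# this must run before the first `import jax`
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

# memory is now allocated on demand rather than reserved up front
x = jnp.ones((1024, 1024))
print(jax.devices())
```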

A more reliable way of monitoring your JAX memory is jax-smi: https://github.com/ayaka14732/jax-smi
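
From memory, usage looks roughly like this (a sketch of the jax-smi API as I recall it from its README; check the linked repo for the exact calls):

```python
# pip install jax-smi
from jax_smi import initialise_tracking

# start background tracking of device memory so the `jax-smi` CLI can report it
initialise_tracking()

# ... run your JAX workload, then inspect usage from another terminal with `jax-smi`
```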

Still working on figuring out how we can load the large-v2 checkpoint on a 16 GB T4 GPU!

sanchit-gandhi avatar May 02 '23 09:05 sanchit-gandhi