whisper-jax

Why isn't whisper-jax using my GPU?

Open · bk111 opened this issue 1 year ago · 3 comments

```python
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate pipeline in half precision (float16)
pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.float16, batch_size=16)

text = pipeline("10m.mp3")
print(text)
```
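Before debugging whisper-jax itself, it can help to confirm that JAX sees a GPU at all, since whisper-jax runs on whatever backend JAX selects. The sketch below assumes JAX 0.4.x, where CUDA devices report the platform string `"gpu"`; `gpu_platforms_present` and `jax_platforms` are hypothetical helper names. If only CPU devices are listed, the installed jaxlib is most likely a CPU-only build.

```python
# Sketch: check which devices JAX (and therefore whisper-jax) can see.
# Assumption: JAX 0.4.x, where CUDA devices report platform "gpu".
try:
    import jax
except ImportError:  # JAX is not installed in this environment
    jax = None


def gpu_platforms_present(platforms):
    """Return True if any device platform string is 'gpu'."""
    return any(p.lower() == "gpu" for p in platforms)


def jax_platforms():
    """List the platform of every device JAX can see (empty if JAX is missing)."""
    if jax is None:
        return []
    return [d.platform for d in jax.devices()]


if jax is not None:
    print("JAX version:", jax.__version__)
    print("Devices:", jax.devices())
    print("GPU visible:", gpu_platforms_present(jax_platforms()))
```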

bk111 · Jan 03 '24 06:01

You might need to check your jax version.

I am using jax 0.4.19, and it works.

```
!pip install jax==0.4.19
!pip install -U "jax[cuda12_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
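One way to check which jax/jaxlib pair actually ended up installed after running these commands is to read the package metadata. This sketch assumes that, for the 0.4.19-era wheels, a CUDA-enabled jaxlib from the `jax_cuda_releases` index carries a local version label such as `+cuda12.cudnn89`, while the CPU-only PyPI wheel has none; `parse_jaxlib_version` and `installed_version` are hypothetical helper names. Newer JAX releases package CUDA support differently, so this check may not carry over to them.

```python
from importlib import metadata


def parse_jaxlib_version(version):
    """Split e.g. '0.4.19+cuda12.cudnn89' into ('0.4.19', 'cuda12.cudnn89').

    Assumption: CUDA jaxlib wheels carry a '+cuda...' local version label;
    the CPU-only wheel has no label, so the second element is ''.
    """
    public, _, local = version.partition("+")
    return public, local


def installed_version(dist):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


jax_v = installed_version("jax")
jaxlib_v = installed_version("jaxlib")
if jax_v and jaxlib_v:
    public, local = parse_jaxlib_version(jaxlib_v)
    print("jax", jax_v, "| jaxlib", public, "| build:", local or "CPU-only")
    if public != jax_v:
        # The commands above pin both packages to the same version.
        print("warning: jax and jaxlib versions differ")
```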

nftblackmagic · Jan 05 '24 16:01

I tried this on Windows 11 (Settings > Installed apps lists NVIDIA driver 546.33 and NVIDIA CUDA 11.8; nvidia-smi reports CUDA 12.3), running WSL2 with Ubuntu and a Docker container based on the nvidia/cuda image (https://hub.docker.com/r/nvidia/cuda/tags?page=2&name=11.8), with no luck.

Do I need "jax[cuda12_pip]==0.4.19" or "jax[cuda11_pip]==0.4.19"? And which container version should I use?
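On the cuda11 vs. cuda12 question, a rough rule of thumb: `nvidia-smi` reports the highest CUDA version the installed driver supports (12.3 here), and the `*_pip` extras pull the CUDA libraries in as pip wheels rather than relying on the system toolkit, so the driver is usually the deciding factor. The sketch below encodes that simplified rule; `pick_jax_cuda_extra` is a hypothetical helper, and it deliberately ignores minor-version and cuDNN constraints.

```python
def pick_jax_cuda_extra(driver_cuda_version):
    """Pick a jax extra from the CUDA version that nvidia-smi reports.

    nvidia-smi shows the highest CUDA version the *driver* supports.
    Simplified assumption: only the major version matters here.
    """
    major = int(driver_cuda_version.split(".")[0])
    if major >= 12:
        return "cuda12_pip"
    if major == 11:
        return "cuda11_pip"
    raise ValueError(f"CUDA {driver_cuda_version} is too old for these extras")


# The driver in this thread reports CUDA 12.3.
print(pick_jax_cuda_extra("12.3"))  # cuda12_pip
```

For the 12.3 reported in this thread, the rule points to `jax[cuda12_pip]`.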

bk111 · Jan 06 '24 02:01


Now the GPU is working, but it is slow: a 10-minute MP3 took more than 300 s to transcribe.
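Part of a slow first run can be JAX compilation: whisper-jax's pipeline is compiled on its first call, and later calls with the same input shapes reuse the compiled function. A simple way to separate the two costs is to time the same call more than once. The helpers `time_calls` and `real_time_factor` below are hypothetical names for this sketch, not whisper-jax APIs.

```python
import time


def time_calls(fn, *args, repeats=2):
    """Call fn(*args) `repeats` times and return each wall-clock duration in seconds.

    For a jit-compiled function, the first duration includes compilation;
    later durations reflect steady-state execution.
    """
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        durations.append(time.perf_counter() - start)
    return durations


def real_time_factor(audio_seconds, wall_seconds):
    """Seconds of audio processed per second of wall-clock time."""
    return audio_seconds / wall_seconds


# Figures from this thread: 10 minutes of audio, about 300 s to transcribe.
print(real_time_factor(10 * 60, 300))  # 2.0
```

Assuming `pipeline` is the `FlaxWhisperPipline` from the original post, `time_calls(pipeline, "10m.mp3", repeats=2)` would return the first (compile plus transcribe) and second (transcribe only) durations; if only the first is slow, compilation accounts for most of the 300 s.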

bk111 · Jan 07 '24 13:01