whisper-jax
Why did whisper-jax not use my GPU?
```python
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate pipeline in half precision (float16)
pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.float16, batch_size=16)

# transcribe a 10-minute audio file
text = pipeline("10m.mp3")
print(text)
```
You might need to check your jax version. I am using jax 0.4.19, and it works:

```
!pip install jax==0.4.19
!pip install -U "jax[cuda12_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
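To confirm that JAX actually sees the GPU rather than falling back to the CPU, a minimal sketch like the one below may help. It only relies on `jax.__version__`, `jax.default_backend()` and `jax.devices()`; the helper name `summarize_platforms` is hypothetical, and treating a platform string of `"gpu"` or `"cuda"` as a CUDA device is an assumption about how your JAX version labels devices.

```python
import importlib.util
from typing import Iterable, List, Tuple


def summarize_platforms(platforms: Iterable[str]) -> Tuple[bool, List[str]]:
    """Return (gpu_visible, sorted unique platform names).

    Assumption: CUDA devices report their platform as "gpu" (or "cuda").
    """
    unique = sorted(set(platforms))
    return any(p in ("gpu", "cuda") for p in unique), unique


def main() -> None:
    # Report a missing install instead of failing with ImportError.
    if importlib.util.find_spec("jax") is None:
        print("jax is not installed")
        return
    import jax

    print("jax version:", jax.__version__)
    print("default backend:", jax.default_backend())
    gpu_visible, unique = summarize_platforms(d.platform for d in jax.devices())
    print("platforms:", unique, "| GPU visible:", gpu_visible)


if __name__ == "__main__":
    main()
```

If the default backend prints `cpu`, the installed jaxlib most likely has no working CUDA support, and the pipeline will run on the CPU however it is configured.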
I tried it in a Docker container built from the nvidia/cuda image (https://hub.docker.com/r/nvidia/cuda/tags?page=2&name=11.8), running in Ubuntu on WSL2 on Windows 11 (Settings > Installed apps lists NVIDIA driver 546.33 and NVIDIA CUDA 11.8, while `nvidia-smi` reports CUDA 12.3), with no luck.
Do I need "jax[cuda12_pip]==0.4.19" or "jax[cuda11_pip]==0.4.19"? And which container version should I use?
Now the GPU is working, but it is slow: transcribing a 10-minute MP3 took more than 300 s.
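One thing worth checking before comparing numbers: JAX compiles a jitted function the first time it is called with a given input shape, so the first transcription includes compilation time. A minimal, standard-library-only sketch that times repeated calls separately (the helper `time_calls` is hypothetical, not part of whisper-jax):

```python
import time
from typing import Any, Callable, List


def time_calls(fn: Callable[..., Any], *args: Any, repeats: int = 2) -> List[float]:
    """Call fn(*args) `repeats` times and return wall-clock seconds per call.

    The first entry includes any one-off JIT compilation; later entries
    reflect steady-state speed for the same input shapes.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    durations: List[float] = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)  # result discarded; only the elapsed time matters here
        durations.append(time.perf_counter() - start)
    return durations


# Hypothetical usage with the pipeline from the question:
# first, second = time_calls(pipeline, "10m.mp3", repeats=2)
# print(f"first call: {first:.1f} s, second call: {second:.1f} s")
```

If the second call is much faster than the first, most of the 300 s was likely compilation rather than transcription.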