whisper-jax
Whisper JAX is not faster than Whisper in a Colab GPU environment
Whisper JAX is not faster than Whisper in a Colab T4 GPU environment. Why? I tested with an 841-second audio file: Whisper JAX took 182 seconds, while Whisper took only 148 seconds (both using the small model).
Please refer to the Whisper JAX test code: https://drive.google.com/file/d/1T9sGsOS4md5169jAnSpQX_tHGbS4yFEC/view?usp=sharing
I have the same question.
Not only on Colab but also on consumer hardware. I can run Whisper medium on my 8 GB VRAM GPU with no issue, but with Whisper JAX I have no idea why I need to run it in dtype float16 to avoid an OOM error. Is there a logical explanation for this?
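As a rough sanity check on the float16 question, the arithmetic below estimates the memory needed just to hold the weights at each dtype. It assumes Whisper medium has about 769M parameters (the figure listed in the OpenAI Whisper README) and counts weights only; it ignores activations, decoding buffers, XLA workspace, and JAX's default GPU memory preallocation, any of which could account for the rest of the gap.

```python
# Bytes per parameter for common floating-point dtypes.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Memory needed only to store the weights, in GB (10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Assumption: Whisper medium has roughly 769M parameters.
WHISPER_MEDIUM_PARAMS = 769_000_000

for dtype in ("float32", "float16"):
    print(f"{dtype}: {weight_memory_gb(WHISPER_MEDIUM_PARAMS, dtype):.2f} GB")
# float32: 3.08 GB
# float16: 1.54 GB
```

Halving the bytes per parameter halves the weight footprint, but even float32 weights fit well inside 8 GB, which suggests the OOM comes from something other than the weights alone.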
Hey @bianxg - it looks like you're measuring the compilation time, which is supposed to be slow. Any subsequent calls to the pipeline will be much faster since we leverage the compiled function. You can see this in action in this Kaggle notebook: https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu
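To avoid counting compilation in a speed comparison, time the first call separately from later calls. The helper below is a generic sketch using only the standard library; `benchmark` is a hypothetical name, not part of whisper-jax.

```python
import statistics
import time
from typing import Any, Callable, Dict


def benchmark(fn: Callable[..., Any], *args: Any, repeats: int = 3, **kwargs: Any) -> Dict[str, float]:
    """Time the first call of `fn` separately from `repeats` later calls.

    For a JIT-compiled function the first call includes tracing and
    compilation, so only the later calls reflect steady-state speed.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")

    # First call: includes any one-off compilation cost.
    start = time.perf_counter()
    fn(*args, **kwargs)
    first = time.perf_counter() - start

    # Later calls: reuse the already-compiled (cached) function.
    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        later.append(time.perf_counter() - start)

    return {
        "first_call_s": first,
        "steady_state_mean_s": statistics.mean(later),
        "steady_state_min_s": min(later),
    }
```

Assuming the `FlaxWhisperPipline` class shown in the whisper-jax README (check the README for the current name and arguments), usage would look like `benchmark(FlaxWhisperPipline("openai/whisper-small"), "audio.mp3")`, and the original Whisper should be measured the same way so that steady state is compared with steady state. If you time a raw JAX function that returns arrays rather than text, call `.block_until_ready()` on the result inside `fn`, because JAX dispatches work asynchronously.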
@bianxg @r2d209git @WasamiKirua
I think that, besides the compile time, the problem is that we are all looking at the number from the demo. That number only covers the forward pass, which @sanchit-gandhi has updated. I'm getting the same number as the demo on TPUs; however, where it gets really slow is the post-processing step, where [_decode_asr](https://github.com/huggingface/transformers/blob/28de2f4de3f7bf5dbe995237b039e95b446038c3/src/transformers/models/whisper/tokenization_whisper.py#L882) is called, and that's just super slow.
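To see how much of the total goes to the forward pass versus decoding, each stage can be timed on its own. This is a minimal, library-agnostic sketch; `StageTimer` is a hypothetical helper, and the stage names are just labels.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Iterator


class StageTimer:
    """Accumulate wall-clock time per named stage."""

    def __init__(self) -> None:
        self.totals: Dict[str, float] = defaultdict(float)

    @contextmanager
    def stage(self, name: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            # Record the elapsed time even if the stage raises.
            self.totals[name] += time.perf_counter() - start

    def report(self) -> str:
        """One line per stage: seconds and share of the summed total."""
        total = sum(self.totals.values())
        lines = []
        for name, secs in self.totals.items():
            share = 100 * secs / total if total else 0.0
            lines.append(f"{name}: {secs:.3f} s ({share:.1f}%)")
        return "\n".join(lines)
```

In use, you would wrap the model call in `with timer.stage("transcription"):` and the token-to-text decoding (where `_decode_asr` runs) in `with timer.stage("post-processing"):`, then `print(timer.report())`. How to split the pipeline into those two calls depends on the code you are running; this sketch does not assume any particular whisper-jax method.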
Has anyone had luck getting better results?
@sanchit-gandhi Any tips on how that was optimized for the demo? I'm at around 11 s, whereas it feels like 1-3 seconds on your Hugging Face demo.
Some benchmarks on my side (times in seconds):
On a T4 GPU:
transcription: 111.91071248054504
post-processing: 12.171269178390503
wall time: 2 min 11 s
On a v3-8 TPU:
transcription: 3.624922275543213
post-processing: 12.646348237991333
wall time: 22.1 s
As you can see: same audio, same post-processing time, but a huge transcription speedup!
Any help lowering those 12 seconds would be much appreciated.
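Putting the reported numbers in proportion shows why post-processing now dominates on the TPU: once the forward pass is about 30x faster, the unchanged decoding step becomes most of the wall time. The snippet only does arithmetic on the timings posted above; it assumes "2 min 11 s" means 131 s, and the "other" column is whatever the two timed stages do not cover (the post does not say what that is).

```python
# Timings in seconds, copied from the benchmark above.
RUNS = {
    "T4 GPU": {"transcription": 111.91071248054504, "post_processing": 12.171269178390503, "wall": 131.0},
    "v3-8 TPU": {"transcription": 3.624922275543213, "post_processing": 12.646348237991333, "wall": 22.1},
}


def post_processing_share(run: dict) -> float:
    """Fraction of wall time spent in post-processing."""
    return run["post_processing"] / run["wall"]


for name, run in RUNS.items():
    share = 100 * post_processing_share(run)
    other = run["wall"] - run["transcription"] - run["post_processing"]
    print(f"{name}: post-processing {share:.1f}% of wall time, other {other:.1f} s")

speedup = RUNS["T4 GPU"]["transcription"] / RUNS["v3-8 TPU"]["transcription"]
print(f"transcription speedup: {speedup:.1f}x")
```

This is Amdahl's law in practice: further speeding up the forward pass barely moves the TPU wall time, so the remaining gains have to come from the post-processing step.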