
Whisper JAX is not faster than Whisper in the Colab GPU environment.

Open · bianxg opened this issue 1 year ago • 4 comments

Whisper JAX is not faster than Whisper in the Colab T4 GPU environment. Why? I tested with an 841-second audio file: Whisper JAX took 182 seconds, while Whisper took only 148 seconds (both using the small model).

For reference, here is the Whisper JAX test code: https://drive.google.com/file/d/1T9sGsOS4md5169jAnSpQX_tHGbS4yFEC/view?usp=sharing

bianxg · Oct 23 '23 12:10

I have the same question.

r2d209git · Oct 31 '23 07:10

Not only on Colab but also on consumer hardware. I can run Whisper medium on my GPU with 8 GB of VRAM with no issue, but with Whisper JAX I have to load the model in dtype float16 to avoid an OOM error, and I have no idea why. Is there a logical explanation for this?

WasamiKirua · Nov 27 '23 11:11
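For reference, the whisper-jax README loads models in half precision via the `dtype` argument; a minimal sketch along those lines (the model id matches the medium checkpoint mentioned above, and `audio.mp3` is a placeholder path):

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# Loading the weights in float16 roughly halves parameter memory vs float32,
# which can be the difference between fitting in 8 GB of VRAM and an OOM.
pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.float16)

text = pipeline("audio.mp3")  # placeholder path
```

Another likely factor is that JAX preallocates a large fraction of GPU memory up front by default (tunable via the `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables), so its apparent footprint behaves quite differently from PyTorch's on-demand allocator.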

Hey @bianxg - it looks like you're measuring the compilation time, which is expected to be slow. Any subsequent calls to the pipeline will be much faster, since they reuse the compiled function. You can see this in action in this Kaggle notebook: https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu

sanchit-gandhi · Dec 15 '23 11:12
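A minimal sketch of separating compile time from steady-state time, using the pipeline API from the whisper-jax README (`audio.mp3` is a placeholder path):

```python
import time
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-small")

# First call: JAX traces and JIT-compiles the forward pass, so it is slow.
start = time.perf_counter()
pipeline("audio.mp3")  # placeholder path
print(f"first call (includes compilation): {time.perf_counter() - start:.1f}s")

# Second call: reuses the cached compiled function; this is the real speed.
start = time.perf_counter()
pipeline("audio.mp3")
print(f"second call (compiled): {time.perf_counter() - start:.1f}s")
```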

@bianxg @r2d209git @WasamiKirua I think the problem, beyond the compile time, is that we are all looking at the number from the demo. That number covers only the forward pass, which @sanchit-gandhi has updated. I'm getting the same numbers as the demo on TPUs; however, where it gets really slow is the post-processing step, where [_decode_asr](https://github.com/huggingface/transformers/blob/28de2f4de3f7bf5dbe995237b039e95b446038c3/src/transformers/models/whisper/tokenization_whisper.py#L882) is called, and that's just super slow.

Has anyone had luck getting better results?

@sanchit-gandhi Any tips on how that was optimized in the demo? I'm at around 11 s, whereas it feels like 1-3 seconds on the Hugging Face demo you have.

Some benchmarks on my side (same audio file in both runs):

On a T4 GPU:
- transcription: 111.91 s
- post-processing: 12.17 s
- wall time: 2 min 11 s

On a v3-8 TPU:
- transcription: 3.62 s
- post-processing: 12.65 s
- wall time: 22.1 s

As you can see: same audio, same post-processing time, huge transcription boost!

Help on how to lower that 12 seconds is much appreciated.

RezaTokhshid · Mar 15 '24 17:03
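One way to confirm where the time goes is to wrap that helper with a timer. A sketch, assuming the pipeline routes through `WhisperTokenizer._decode_asr` as linked above (this is private transformers API and may differ between versions; if the fast tokenizer is used, patch `tokenization_whisper_fast.WhisperTokenizerFast._decode_asr` instead):

```python
import time
from transformers.models.whisper import tokenization_whisper

# Keep a handle to the original private post-processing helper.
_orig_decode_asr = tokenization_whisper.WhisperTokenizer._decode_asr

def _timed_decode_asr(self, *args, **kwargs):
    # Time each post-processing call so its cost shows up separately
    # from the (JIT-compiled) forward pass.
    start = time.perf_counter()
    result = _orig_decode_asr(self, *args, **kwargs)
    print(f"_decode_asr: {time.perf_counter() - start:.2f}s")
    return result

tokenization_whisper.WhisperTokenizer._decode_asr = _timed_decode_asr
```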