
whisper-jax on TPU is much slower than faster-whisper on a T4

Open · ILG2021 opened this issue · 6 comments

I use whisper-jax on TPU for my speech-translation project. I found that when it recognizes a long sentence, it takes about 40s, whereas faster-whisper on a Kaggle T4 only takes 5-6s. That doesn't look like the power of a TPU v3. I use float32 for both, because accuracy is important for us.

ILG2021 · Aug 24 '23 14:08

Same here; it seems really slow for me too on a TPU.

couldbejake · Sep 11 '23 17:09

Update: I managed to get it working. The first call to the API seems to take a long time, as it caches "whatever" internally; it's not very clear what. Once the first text generation has completed, all subsequent generations run at what I would describe as 70x speed.

The solution I came up with was to play a test audio clip on initialisation.
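For reference, a minimal warm-up sketch along those lines; it assumes the `FlaxWhisperPipline` API from the whisper-jax README, and the checkpoint name and audio paths are placeholders:

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# Placeholder checkpoint; float32 as in the original report.
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float32)

# The first call triggers JIT/pmap compilation, which can take minutes
# on a TPU. "warmup.wav" is any short placeholder clip.
pipeline("warmup.wav")

# Subsequent calls reuse the compiled function and run at full speed.
text = pipeline("real_audio.wav")
```

Because the pipeline chunks audio into fixed-length windows, the compiled function should be reused across inputs of different lengths, so a single short warm-up clip is enough.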

couldbejake · Sep 11 '23 22:09

OK, I will test it later.

ILG2021 · Sep 11 '23 23:09

> Update: I managed to get it working. The first call to the API seems to take a long time, as it caches "whatever" internally; it's not very clear what. Once the first text generation has completed, all subsequent generations run at what I would describe as 70x speed.
>
> The solution I came up with was to play a test audio clip on initialisation.

It is in the notebooks:

> We'll need to compile the pmap function the first time we use it. You can expect compilation to take ~2 minutes on a TPU v3-8 with a batch size of 16. Enough time to grab a coffee ☕️
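To make the compile-then-cache behaviour concrete, here is a self-contained JAX sketch (a toy function, not whisper-jax itself) showing the first `pmap` call paying the one-off compilation cost:

```python
import time
import jax
import jax.numpy as jnp

# A toy function parallelised across devices with pmap.
@jax.pmap
def square(x):
    return x * x

n = jax.device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)

# First call: traces and compiles the function (slow, one-off).
start = time.perf_counter()
square(x).block_until_ready()
print(f"first call (includes compile): {time.perf_counter() - start:.3f}s")

# Second call: reuses the compiled executable (fast).
start = time.perf_counter()
square(x).block_until_ready()
print(f"second call (cached): {time.perf_counter() - start:.6f}s")
```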

lparisi · Oct 04 '23 16:10

It is not clear enough.

couldbejake · Oct 04 '23 16:10

@couldbejake @lparisi @ILG2021 I think that, compile time aside, the problem is that we are all comparing against the number from the demo. That number covers only the forward pass, which @sanchit-gandhi has updated. I'm getting the same number as the demo on TPUs; where it gets really slow is the post-processing step, where [_decode_asr](https://github.com/huggingface/transformers/blob/28de2f4de3f7bf5dbe995237b039e95b446038c3/src/transformers/models/whisper/tokenization_whisper.py#L882) is called, and that is just super slow.
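For context, a rough sketch of timing that post-processing step in isolation. Note that `_decode_asr` is a private transformers helper whose signature can change between releases, and `model_outputs` below is a stand-in for whatever the chunked forward pass returned:

```python
import time
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v2")

# Stand-in: a list of dicts with "tokens" (and optionally "stride") keys,
# as produced by the chunked forward pass.
model_outputs = ...

start = time.perf_counter()
text, optional = tokenizer._decode_asr(
    model_outputs,
    return_timestamps=True,
    return_language=False,
    time_precision=0.02,  # Whisper's default: one timestamp token per 20 ms
)
print(f"post-processing: {time.perf_counter() - start:.2f}s")
```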

Has anyone had luck getting better results?

@sanchit-gandhi Any tips on how that was optimized in the demo? I'm at around 11s, while it feels like 1-3 seconds on the Hugging Face demo you have.

Some benchmarks from my side (same audio file in both runs):

| | T4 GPU | v3-8 TPU |
| --- | --- | --- |
| transcription | 111.9 s | 3.6 s |
| post-processing | 12.2 s | 12.6 s |
| Wall time | 2 min 11 s | 22.1 s |

As you can see: same audio, same post-processing time, huge transcription boost!
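In case anyone wants to reproduce the split timings, here is a sketch of measuring the two stages separately. The stage methods (`preprocess_batch` / `forward` / `postprocess`) mirror what the pipeline's `__call__` does internally in whisper_jax/pipeline.py, but they are internal, so verify the names and signatures against your installed version:

```python
import time
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float32)
pipeline("warmup.wav")  # warm up first so compilation doesn't skew the numbers

# Assumed internal stages; check whisper_jax/pipeline.py for exact signatures.
dataloader = pipeline.preprocess_batch("audio.wav", chunk_length_s=30.0, batch_size=16)

start = time.perf_counter()
model_outputs = [pipeline.forward(batch, batch_size=16) for batch in dataloader]
print("transcription", time.perf_counter() - start)

start = time.perf_counter()
result = pipeline.postprocess(model_outputs, return_timestamps=False)
print("post-processing", time.perf_counter() - start)
```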

Help on how to lower those 12 seconds is much appreciated.

RezaTokhshid · Mar 15 '24 17:03