whisper-jax
whisper-jax TPU is much slower than faster-whisper T4
I use whisper-jax on TPU for my speech translation project, and I found that recognizing a long sentence takes about 40 s. But faster-whisper on a Kaggle T4 only takes 5-6 s. That doesn't reflect the power of a TPU v3. I use float32 for both, because accuracy is important for us.
Same here, it seems really slow for me too on a TPU.
Update: I managed to get it working. The first call to the API takes a long time because it caches "whatever" internally - it's not very clear what exactly. Once the first text generation has completed, all subsequent generations run at what I would describe as 70x speed.
The solution I came up with was to play a test audio clip on initialisation.
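For anyone else hitting this, a minimal sketch of that warm-up trick, assuming the `FlaxWhisperPipline` API from the README; the checkpoint name is just an example, and passing audio as an `array`/`sampling_rate` dict mirrors the transformers ASR pipeline, so double-check it against the pipeline source:

```python
import numpy as np
from whisper_jax import FlaxWhisperPipline

# Build the pipeline once at startup (checkpoint name is an example).
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")

# Warm-up call: one second of silence at 16 kHz triggers the one-off JAX
# compilation, so real requests hit the cached fast path.
silence = {"array": np.zeros(16_000, dtype=np.float32), "sampling_rate": 16_000}
pipeline(silence)
```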
Ok, I will test later.
It's explained in the notebooks:

> We'll need to compile the pmap function the first time we use it. You can expect compilation to take ~2 minutes on a TPU v3-8 with a batch size of 16. Enough time to grab a coffee ☕️
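Concretely, with a pipeline built as in the snippet above, timing two consecutive calls should show the first one dominated by compilation and the second one fast (`audio.mp3` is a placeholder for your own clip):

```python
import time

# The first call compiles the pmap'd generation function; later calls with
# the same shapes and dtypes reuse the compilation cache.
for i in range(2):
    start = time.perf_counter()
    pipeline("audio.mp3")  # placeholder path
    print(f"call {i}: {time.perf_counter() - start:.1f} s")
```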
It is not clear enough
@couldbejake @lparisi @ILG2021
I think, compile time aside, the problem is that we are all looking at the number from the demo. That number covers only the forward pass, which @sanchit-gandhi has updated. I'm getting the same numbers as the demo on TPUs; where it gets really slow is the post-processing step, where [`_decode_asr`](https://github.com/huggingface/transformers/blob/28de2f4de3f7bf5dbe995237b039e95b446038c3/src/transformers/models/whisper/tokenization_whisper.py#L882) is called, and that is just super slow.
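To confirm where the time goes, here is a sketch that wraps the tokenizer's `_decode_asr` with a timer. `_decode_asr` is a private transformers method that may change between versions, and `pipeline.tokenizer` is an assumption about where the pipeline keeps its tokenizer, so adjust both to your setup:

```python
import time
from functools import wraps

def instrument_decode_asr(tokenizer):
    # Shadow the private _decode_asr on this instance so every call
    # reports its wall time before returning the normal result.
    original = tokenizer._decode_asr

    @wraps(original)
    def timed(*args, **kwargs):
        start = time.perf_counter()
        result = original(*args, **kwargs)
        print(f"_decode_asr: {time.perf_counter() - start:.2f} s")
        return result

    tokenizer._decode_asr = timed

instrument_decode_asr(pipeline.tokenizer)  # assumes a .tokenizer attribute
pipeline("audio.mp3")  # placeholder path
```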
Has anyone had luck getting better results?
@sanchit-gandhi Any tips on how that was optimized in the demo? I'm at around 11 s, while it feels like 1-3 s on the Hugging Face demo you have.
Some benchmarks on my side:

| Device    | Transcription (s) | Post-processing (s) | Wall time  |
| --------- | ----------------- | ------------------- | ---------- |
| T4 GPU    | 111.91            | 12.17               | 2 min 11 s |
| TPU v3-8  | 3.62              | 12.65               | 22.1 s     |
As you can see: same audio, same post-processing time, huge transcription boost!
Help on how to lower those 12 seconds is much appreciated.