mlx-examples
[Whisper] Add word timestamps and confidence scores
Hi @awni 👋
In this PR, I've added several new features to the Whisper implementation, following the original repository:
- word-level timestamps (https://github.com/openai/whisper/pull/869)
- word-level confidence scores (https://github.com/openai/whisper/commit/5fa43566f00a3e337f7fb481a2b962118453a96b)
- clip_timestamps and hallucination_silence_threshold (https://github.com/openai/whisper/pull/1838)
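For reference, here's a minimal usage sketch. The argument names and result structure follow openai-whisper's `transcribe` options (which this PR ports), so the exact entry point here is an assumption:

```python
import whisper  # the mlx-examples Whisper package

# word_timestamps enables the extra alignment pass;
# clip_timestamps restricts decoding to the given ranges (in seconds);
# hallucination_silence_threshold skips silent gaps longer than N seconds
# when word timestamps are enabled
result = whisper.transcribe(
    "audio.mp3",
    word_timestamps=True,
    clip_timestamps="0,30",
    hallucination_silence_threshold=2.0,
)

for segment in result["segments"]:
    for word in segment["words"]:
        # each word carries start/end times and a confidence probability
        print(f'{word["word"]}\t{word["start"]:.2f}-{word["end"]:.2f}\t{word["probability"]:.3f}')
```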
This is still a draft version that may require some optimizations:
- Move certain numpy operations to MLX. I encountered issues running some operations in MLX, so I left them in np. However, you may have better solutions :)
- More efficient implementations of `median_filter` and `dtw`. I used the `median_filter` from scipy directly, since I didn't find an `unfold` function in mlx. As for `dtw`, I kept the original numba version (a simplified sketch of the alignment step follows this list)
- Better handling of the `qk` attention scores in the model forward pass
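For context, here's a rough, simplified sketch of how the alignment step works upstream: the cross-attention weights are median-filtered along the time axis, then DTW finds a monotonic token-to-frame path. The `dtw` below is a plain-NumPy rendering of the cost/backtrace recursion that the original compiles with numba; shapes and normalization are simplified:

```python
import numpy as np
from scipy.ndimage import median_filter

def dtw(x):
    """Monotonic alignment over a (text_tokens, audio_frames) cost matrix."""
    N, M = x.shape
    cost = np.full((N + 1, M + 1), np.inf, dtype=np.float32)
    trace = -np.ones((N + 1, M + 1), dtype=np.int32)
    cost[0, 0] = 0.0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            # three moves: diagonal (match), down (next token), right (next frame)
            c, t = min(
                (cost[i - 1, j - 1], 0),
                (cost[i - 1, j], 1),
                (cost[i, j - 1], 2),
            )
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    # backtrace from the bottom-right corner to recover the alignment path
    i, j = N, M
    text_indices, time_indices = [], []
    while i > 0 or j > 0:
        text_indices.append(i - 1)
        time_indices.append(j - 1)
        t = trace[i, j]
        if t == 0:
            i, j = i - 1, j - 1
        elif t == 1:
            i -= 1
        else:
            j -= 1
    return np.array(text_indices[::-1]), np.array(time_indices[::-1])

# stand-in attention weights: (text tokens, audio frames)
qk = np.random.rand(12, 100).astype(np.float32)
# smooth along the time axis only, then align; negate so that
# higher attention means lower alignment cost
smoothed = median_filter(qk, size=(1, 7))
text_idx, time_idx = dtw(-smoothed)
```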
Below are the benchmark times from tests run on my M1 Pro.
Feature time 0.038
Model: TINY
Model forward time 0.038
Decode time 0.211
Everything time 0.266
Everything (w/ word_timestamps) time 0.320
--------------------------------------------------
Model: SMALL
Model forward time 0.233
Decode time 0.644
Everything time 0.859
Everything (w/ word_timestamps) time 1.113
--------------------------------------------------
Model: MEDIUM
Model forward time 0.684
Decode time 1.700
Everything time 2.356
Everything (w/ word_timestamps) time 2.914
--------------------------------------------------
Model: LARGE
Model forward time 2.782
Decode time 2.701
Everything time 3.597
Everything (w/ word_timestamps) time 4.823
--------------------------------------------------
Super cool, thanks for adding that!
Addresses #146
After measuring the time taken by the operations that add word-level timestamps/scores, I've found that most of it is consumed by the extra model forward pass. There also appears to be overhead on the first run of DTW, likely due to Numba JIT compilation.
Below are the measured times from tests with the large model.
extra forward time: 1.2198s
median_filter time: 0.0046s
dtw time: 0.7094s
extra forward time: 1.2341s
median_filter time: 0.0044s
dtw time: 0.0012s
extra forward time: 1.2124s
median_filter time: 0.0045s
dtw time: 0.0010s
extra forward time: 1.3064s
median_filter time: 0.0081s
dtw time: 0.0005s
extra forward time: 1.2168s
median_filter time: 0.0045s
dtw time: 0.0003s
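The first-run overhead above is consistent with Numba JIT compilation cost. If it matters, a standard mitigation (not part of this PR, just an option) is to trigger compilation once at import time with a tiny dummy input, and/or persist the compiled code with `cache=True`:

```python
import numpy as np
import numba

@numba.njit(cache=True)  # cache=True persists the compiled code across runs
def dtw_cost(x):
    # minimal DTW cost recursion, a stand-in for the real alignment function
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    cost[0, 0] = 0.0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            cost[i, j] = x[i - 1, j - 1] + min(
                cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1]
            )
    return cost

# warm-up call at import time: pays the ~0.7s compile cost once, up front
dtw_cost(np.zeros((2, 2), dtype=np.float32))
```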
Hi @awni, thanks for the review!
I've just done a rebase and added a test for word-level timestamps & confidence, comparing the results with those from openai-whisper.
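For anyone curious what such a comparison looks like, here's a rough sketch (all names and values are illustrative, not the actual test in the PR): transcribe with word timestamps enabled and check each word's boundaries and probability against reference values captured from openai-whisper, within a small tolerance:

```python
import unittest

import whisper  # the mlx-examples Whisper package

TEST_AUDIO = "tests/test.mp3"  # illustrative path

# reference values captured from openai-whisper on the same audio (illustrative)
EXPECTED_WORDS = [
    {"word": " Then", "start": 0.00, "end": 0.22, "probability": 0.95},
]

class TestWordTimestamps(unittest.TestCase):
    def test_words_match_openai_whisper(self):
        result = whisper.transcribe(TEST_AUDIO, word_timestamps=True)
        words = [w for seg in result["segments"] for w in seg["words"]]
        for got, want in zip(words, EXPECTED_WORDS):
            self.assertEqual(got["word"], want["word"])
            self.assertAlmostEqual(got["start"], want["start"], delta=0.05)
            self.assertAlmostEqual(got["end"], want["end"], delta=0.05)
            self.assertAlmostEqual(got["probability"], want["probability"], delta=0.05)

if __name__ == "__main__":
    unittest.main()
```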
Below are the new measured times from tests run on my M1 Pro:
Selected models: ['tiny', 'small', 'medium', 'large-v3']
Feature time 0.035
Model: TINY
Model forward time 0.034
Decode time 0.186
Everything time 0.251
Everything (w/ word_timestamps) time 0.291
--------------------------------------------------
Model: SMALL
Model forward time 0.221
Decode time 0.650
Everything time 0.855
Everything (w/ word_timestamps) time 1.137
--------------------------------------------------
Model: MEDIUM
Model forward time 0.646
Decode time 1.559
Everything time 2.176
Everything (w/ word_timestamps) time 2.832
--------------------------------------------------
Model: LARGE-V3
Model forward time 1.209
Decode time 2.753
Everything time 3.609
Everything (w/ word_timestamps) time 4.953
--------------------------------------------------