WhisperKit
Added TimestampRulesFilter implementation
This PR adds an implementation of TimestampRulesFilter. The implementation is based on https://github.com/openai/whisper/blob/master/whisper/decoding.py#L441
Couple of questions here @ZachNagengast:
- The sampleBegin param passed to TimestampRulesFilter is 0, which I think might be incorrect. I compared it to the Python implementation from the OpenAI repo, and there this param is always greater than or equal to 3 (and this makes sense: the first 3 tokens are the special tokens 50258, 50259, and 50359, and AFAIK we don't want to suppress them). If you run this code as is, some segments might be omitted (because sampleBegin is 0; if you change it to 3, it should be ok). See the sketch after this list.
- This implementation slows down the whole inference code; maybe you have some ideas how to optimize it?
- You mentioned that it has duplicated logic with SegmentSeeker, but I don't see it (AFAIK TimestampRulesFilter just suppresses the token probabilities, while SegmentSeeker creates the whole segments). Could you please clarify?
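To illustrate the first point, here is a minimal sketch of how sampleBegin changes which tokens the filter inspects, assuming the decoder is prefilled with the three special tokens mentioned above; the variable names and the sampled token IDs after the prompt are made up for illustration and are not WhisperKit's actual decoder state.

// Hypothetical sketch: the filter should only look at tokens sampled after the
// prefilled prompt <|startoftranscript|><|en|><|transcribe|> (IDs 50258, 50259, 50359).
let prefilledPrompt = [50258, 50259, 50359]
let decodedTokens = prefilledPrompt + [50364, 1396, 11, 452] // prompt + sampled tokens (IDs illustrative)

let sampleBeginFromPrompt = prefilledPrompt.count // 3

// With sampleBegin == 0 the special tokens are treated as sampled output, so the
// timestamp rules fire on them and segments can get dropped; with sampleBegin == 3
// only the actually sampled tokens are checked.
print(Array(decodedTokens[0...]))                     // [50258, 50259, 50359, 50364, 1396, 11, 452]
print(Array(decodedTokens[sampleBeginFromPrompt...])) // [50364, 1396, 11, 452]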
@jkrukowski I pushed a small commit to measure the logit filtering time. Here is what I'm getting for tiny, with and without these new timestamp rules, on the jfk.wav file:
With:
[WhisperKit] - Logit Filtering: 192.41 ms / 28 runs ( 6.87 ms/run) 37.78%
Without:
[WhisperKit] - Logit Filtering: 0.07 ms / 28 runs ( 0.00 ms/run) 0.02%
This is a bit high, and it becomes especially noticeable with the tiny model. What's interesting is that only the first and last few tokens are slow (graph by ChatGPT, for the jfk.wav file).
Hopefully this gives you some guidance on where to look for optimizations. The majority of the slowdown is in this block of code:
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
let sampledTokens = tokens[sampleBegin...]
let lastWasTimestamp = sampledTokens.count >= 1 && sampledTokens.last! >= timeTokenBegin
let penultimateWasTimestamp = sampledTokens.count < 2 || sampledTokens.dropLast().last! >= timeTokenBegin
if lastWasTimestamp {
    if penultimateWasTimestamp {
        // has to be non-timestamp
        logits.fillLastDimension(indexes: timeTokenBegin..<logits.count, with: -FloatType.infinity)
    } else {
        // cannot be normal text tokens
        logits.fillLastDimension(indexes: 0..<endToken, with: -FloatType.infinity)
    }
}
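For context on where the time goes: fillLastDimension here masks large ranges of the vocabulary (the 0..<endToken branch covers roughly 50k logits), so a per-element approach adds up quickly. Below is a minimal sketch of one way such a fill can be vectorized over the raw MLMultiArray buffer; it assumes Float32 logits with a contiguous last dimension (e.g. shape [1, 1, vocabSize]) and is only an illustration of the approach, not necessarily the version later added in the PR.

import CoreML
import Accelerate

extension MLMultiArray {
    /// Sketch of a vectorized fill over a range of the last dimension.
    /// Assumes Float32 storage; falls back to a strided loop if the last
    /// dimension is not contiguous.
    func fillLastDimensionFast(indexes: Range<Int>, with value: Float) {
        guard dataType == .float32 else { return } // sketch only handles Float32
        withUnsafeMutableBytes { rawBuffer, strides in
            let floats = rawBuffer.bindMemory(to: Float.self)
            let lastStride = strides.last ?? 1
            if lastStride == 1 {
                // Contiguous last dimension: a single vDSP fill over the whole range.
                var fillValue = value
                vDSP_vfill(&fillValue,
                           floats.baseAddress! + indexes.lowerBound,
                           1,
                           vDSP_Length(indexes.count))
            } else {
                // Strided layout: still writes floats directly instead of boxing per element.
                for i in indexes {
                    floats[i * lastStride] = value
                }
            }
        }
    }
}

The key difference from naive subscripting is that writes go straight to the Float buffer, avoiding per-element NSNumber boxing.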
@ZachNagengast I've added a more performant version of the fillLastDimension function, and it seems to be doing better. This is what I get for a release build on the jfk.wav file:
[WhisperKit] ---- Transcription Timings ----
[WhisperKit] Audio Load: 2.33 ms / 1 runs ( 2.33 ms/run) 0.66%
[WhisperKit] Audio Processing: 0.11 ms / 1 runs ( 0.11 ms/run) 0.03%
[WhisperKit] Mels: 35.53 ms / 1 runs ( 35.53 ms/run) 10.11%
[WhisperKit] Encoding: 13.39 ms / 1 runs ( 13.39 ms/run) 3.81%
[WhisperKit] Matrices Init: 0.22 ms / 1 runs ( 0.22 ms/run) 0.06%
[WhisperKit] Prefill: 0.00 ms / 1 runs ( 0.00 ms/run) 0.00%
[WhisperKit] Decoding: 239.40 ms / 28 runs ( 8.55 ms/run) 68.15%
[WhisperKit] Non-inference: 61.25 ms / 28 runs ( 2.19 ms/run) 17.43%
[WhisperKit] - Logit Filtering: 3.24 ms / 28 runs ( 0.12 ms/run) 0.92%
[WhisperKit] - Sampling: 14.17 ms / 28 runs ( 0.51 ms/run) 4.03%
[WhisperKit] - Kv Caching: 2.79 ms / 28 runs ( 0.10 ms/run) 0.80%
[WhisperKit] - Word Timestamps: 0.00 ms / 0 runs ( 0.00 ms/run) 0.00%
[WhisperKit] - Windowing: 0.08 ms / 1 runs ( 0.08 ms/run) 0.02%
[WhisperKit] Fallbacks: 0.00 ms / 0 runs ( 0.00 ms/run) 0.00%
[WhisperKit] Decoding Full Loop: 351.06 ms / 28 runs ( 12.54 ms/run) 99.93%
Much better! This looks in line with what I was seeing for those faster middle tokens previously. Think this is ready to come out of draft now?
Good to hear this, 2 things are left:
- self.sampleBegin = 3 // FIXME: it should not be hardcoded value -- not sure what value I should put there
- force unwrapping in sumOfProbabilityOverTimestampsIsAboveAnyOtherToken -- maybe we should not force unwrap and return false gracefully instead, wdyt?
- self.sampleBegin = 3 // FIXME: it should not be hardcoded value -- not sure what value I should put there
PrefilledIndex is already being passed into this function, but I think it should actually use initialPromptIndex. A good test to add for accuracy on this would be similar to this one https://github.com/argmaxinc/WhisperKit/blob/e45dc0a056197c4a4ee3dabe9c604f48b150e519/Tests/WhisperKitTests/UnitTests.swift#L314 where you'd create a bunch of options that change this initialPromptIndex and make sure it's working properly.
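As a rough outline of that test, here is a hypothetical sketch: the transcribe helper below is a stand-in for however the linked UnitTests.swift test sets up WhisperKit and its decoding options (it is not real WhisperKit API); only the shape of the assertion is the point.

import XCTest

final class TimestampRulesSampleBeginTests: XCTestCase {
    // Stand-in for a real transcription call whose options set a given
    // initialPromptIndex; replace with the project's actual test helpers.
    func transcribeJFK(initialPromptIndex: Int) async throws -> String {
        throw XCTSkip("wire this up to the real transcription helper")
    }

    func testTranscriptionUnaffectedByInitialPromptIndex() async throws {
        var transcripts = [String]()
        for promptIndex in [3, 4, 5] {
            // If sampleBegin tracks initialPromptIndex correctly, no leading
            // segments should be suppressed regardless of the prompt length.
            transcripts.append(try await transcribeJFK(initialPromptIndex: promptIndex))
        }
        XCTAssertEqual(Set(transcripts).count, 1)
    }
}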
- force unwrapping in sumOfProbabilityOverTimestampsIsAboveAnyOtherToken -- maybe we should not force unwrap and return false gracefully instead, wdyt?
Besides the verbosity I think it's ok. If you want to be extra safe, you can wrap that whole part in a do/catch and log an error, similar to the sampling code. I'm not sure of all the scenarios where BNNS will throw, but returning false would just fall back to the default behavior, so no issues there.
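For illustration, here is a minimal sketch of the graceful-fallback idea. It stands in for the real sumOfProbabilityOverTimestampsIsAboveAnyOtherToken check: the vDSP-based log-sum-exp and the [Float] signature are assumptions rather than the PR's BNNS-based implementation, and guards replace force unwraps so that any bad input simply returns false and falls back to the default behavior.

import Accelerate
import Foundation

// Sketch of the rule under discussion: the total probability mass over the
// timestamp tokens (log-sum-exp of their logits) must exceed the most probable
// text token. Names and math here are illustrative assumptions.
func timestampProbabilityDominatesTextTokens(logits: [Float], timeTokenBegin: Int) -> Bool {
    // Guard instead of force unwrapping: invalid input falls back to default behavior.
    guard timeTokenBegin > 0, timeTokenBegin < logits.count else { return false }

    let textLogits = Array(logits[..<timeTokenBegin])
    let timestampLogits = Array(logits[timeTokenBegin...])

    // Stabilized log-sum-exp over the timestamp logits.
    let maxTimestamp = vDSP.maximum(timestampLogits)
    let sumExp = vDSP.add(-maxTimestamp, timestampLogits).reduce(Float(0)) { $0 + exp($1) }
    let timestampLogProb = maxTimestamp + log(sumExp)

    // Compare against the best text-token logit; the shared softmax normalizer
    // cancels, so raw logits can be compared directly.
    return timestampLogProb > vDSP.maximum(textLogits)
}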