
Added TimestampRulesFilter implementation

Open · jkrukowski opened this pull request 11 months ago · 1 comment

This PR adds an implementation of TimestampRulesFilter, based on https://github.com/openai/whisper/blob/master/whisper/decoding.py#L441

A couple of questions here, @ZachNagengast:

  • The sampleBegin param passed to TimestampRulesFilter is 0, which I think is incorrect. I compared it to the Python implementation in the OpenAI repo, and there this param is always greater than or equal to 3 (which makes sense: the first 3 tokens are special tokens, 50258, 50259, and 50359, and AFAIK we don't want to suppress them). If you run this code as is, some segments might be omitted, because sampleBegin is 0; if you change it to 3, it should be ok (see the sketch after this list).
  • This implementation slows down the whole inference; maybe you have some ideas on how to optimize it?
  • You mentioned that it has duplicated logic with SegmentSeeker, but I don't see it (AFAIK TimestampRulesFilter just suppresses token probabilities, while SegmentSeeker creates the whole segments). Could you please clarify?
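
For illustration, a minimal sketch of the sampleBegin point from the first question (the special token IDs are the multilingual ones mentioned above; the trailing text token values are made up):

    let tokens = [50258, 50259, 50359, 50364, 1396, 264] // <|startoftranscript|>, <|en|>, <|transcribe|>, <|0.00|>, text...
    // With sampleBegin == 0, the special prefix is treated as sampled output
    // and the timestamp pairing rules fire on tokens that were never sampled:
    let wrong = tokens[0...] // includes the three special tokens
    // With sampleBegin == 3, only freely decoded tokens are inspected:
    let right = tokens[3...] // starts at the first real decoder output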

jkrukowski · Mar 04 '24 13:03

@jkrukowski I pushed a small commit to measure the logit filtering time. Here is what I'm getting for tiny, with and without these new timestamp rules, on the jfk.wav file:

With:    [WhisperKit] - Logit Filtering:   192.41 ms /     28 runs (    6.87 ms/run) 37.78%
Without: [WhisperKit] - Logit Filtering:     0.07 ms /     28 runs (    0.00 ms/run)  0.02%

This is a bit high, and it becomes especially noticeable with the tiny model. What's interesting is that only the first and last few tokens are slow (graph by ChatGPT). This is for jfk.wav:

[graph: logit filtering time per token, with spikes at the first and last few tokens]

Hopefully this gives you some guidance on where to look for optimizations. The majority of the slowdown is in this block of code:

            // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
            let sampledTokens = tokens[sampleBegin...]
            let lastWasTimestamp = sampledTokens.count >= 1 && sampledTokens.last! >= timeTokenBegin
            let penultimateWasTimestamp = sampledTokens.count < 2 || sampledTokens.dropLast().last! >= timeTokenBegin
            if lastWasTimestamp {
                if penultimateWasTimestamp {
                    // has to be non-timestamp
                    logits.fillLastDimension(indexes: timeTokenBegin..<logits.count, with: -FloatType.infinity)
                } else {
                    // cannot be normal text tokens
                    logits.fillLastDimension(indexes: 0..<endToken, with: -FloatType.infinity)
                }
            }
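
A likely culprit is that a straightforward fillLastDimension writes each element through MLMultiArray subscripting, which boxes every value in an NSNumber. For illustration, a hedged sketch of a vectorized fill; it assumes Float32, contiguous storage, and that the logits live in the innermost dimension, and the extension itself is my sketch, not the PR's code:

    import Accelerate
    import CoreML

    extension MLMultiArray {
        // Fill a contiguous index range with a constant in one vDSP call
        // instead of per-element NSNumber subscripting.
        func fillLastDimension(indexes: Range<Int>, with value: Float) {
            precondition(dataType == .float32, "sketch assumes Float32 logits")
            var v = value
            let base = dataPointer.assumingMemoryBound(to: Float.self)
            // Flat offsets equal last-dimension indexes only when the array
            // is contiguous and shaped like [1, 1, vocabSize].
            vDSP_vfill(&v, base + indexes.lowerBound, 1, vDSP_Length(indexes.count))
        }
    }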

ZachNagengast · Mar 13 '24 00:03

@ZachNagengast I've added a more performant version of the fillLastDimension function. It seems to be doing better; this is what I get for a release build on the jfk.wav file:

[WhisperKit] ---- Transcription Timings ----
[WhisperKit] Audio Load:              2.33 ms /      1 runs (    2.33 ms/run)  0.66%
[WhisperKit] Audio Processing:        0.11 ms /      1 runs (    0.11 ms/run)  0.03%
[WhisperKit] Mels:                   35.53 ms /      1 runs (   35.53 ms/run) 10.11%
[WhisperKit] Encoding:               13.39 ms /      1 runs (   13.39 ms/run)  3.81%
[WhisperKit] Matrices Init:           0.22 ms /      1 runs (    0.22 ms/run)  0.06%
[WhisperKit] Prefill:                 0.00 ms /      1 runs (    0.00 ms/run)  0.00%
[WhisperKit] Decoding:              239.40 ms /     28 runs (    8.55 ms/run) 68.15%
[WhisperKit] Non-inference:          61.25 ms /     28 runs (    2.19 ms/run) 17.43%
[WhisperKit] - Logit Filtering:       3.24 ms /     28 runs (    0.12 ms/run)  0.92%
[WhisperKit] - Sampling:             14.17 ms /     28 runs (    0.51 ms/run)  4.03%
[WhisperKit] - Kv Caching:            2.79 ms /     28 runs (    0.10 ms/run)  0.80%
[WhisperKit] - Word Timestamps:       0.00 ms /      0 runs (    0.00 ms/run)  0.00%
[WhisperKit] - Windowing:             0.08 ms /      1 runs (    0.08 ms/run)  0.02%
[WhisperKit] Fallbacks:               0.00 ms /      0 runs (    0.00 ms/run)  0.00%
[WhisperKit] Decoding Full Loop:    351.06 ms /     28 runs (   12.54 ms/run) 99.93%

jkrukowski · Mar 19 '24 16:03

Much better! This looks in line with what I was seeing for those faster middle tokens previously. Think this is ready to come out of draft now?

ZachNagengast · Mar 19 '24 17:03

Good to hear! Two things are left:

  1. self.sampleBegin = 3 // FIXME: it should not be a hardcoded value -- I'm not sure what value to put there
  2. force unwrapping in sumOfProbabilityOverTimestampsIsAboveAnyOtherToken -- maybe we should not force unwrap and instead return false gracefully, wdyt?

jkrukowski · Mar 19 '24 17:03

  1. self.sampleBegin = 3 // FIXME: it should not be a hardcoded value -- I'm not sure what value to put there

PrefilledIndex is already being passed into this function, but I think it should actually use initialPromptIndex. A good test to add for accuracy here would be similar to this one: https://github.com/argmaxinc/WhisperKit/blob/e45dc0a056197c4a4ee3dabe9c604f48b150e519/Tests/WhisperKitTests/UnitTests.swift#L314 where you'd create a bunch of options that change initialPromptIndex and make sure it's working properly.
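
For illustration, a minimal sketch of deriving sampleBegin from whatever prompt was actually prefilled rather than hardcoding 3 (prefilledTokens is a hypothetical stand-in for the tokens fed to the decoder before sampling starts):

    let prefilledTokens = [50258, 50259, 50359] // <|startoftranscript|>, <|en|>, <|transcribe|>
    let sampleBegin = prefilledTokens.count     // timestamp rules apply from here on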

  2. force unwrapping in sumOfProbabilityOverTimestampsIsAboveAnyOtherToken -- maybe we should not force unwrap and instead return false gracefully, wdyt?

Besides the verbosity, I think it's ok. If you want to be extra safe, you can wrap that whole part in a do/catch and log an error, similar to the sampling code. I'm not sure of all the scenarios where BNNS will throw, but returning false would just fall back to the default behavior, so no issues there.
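
For what it's worth, a hedged sketch of that graceful fallback: the comparison mirrors the OpenAI rule (total probability mass on timestamp tokens vs. the most likely text token), while the error type and Float32 guard are my assumptions, not the PR's code:

    import CoreML
    import Foundation

    enum LogitError: Error { case unsupportedDataType }

    // Returns false on any failure so decoding falls back to default
    // behavior instead of crashing on a force unwrap.
    func sumOfProbabilityOverTimestampsIsAboveAnyOtherToken(
        logits: MLMultiArray, timeTokenBegin: Int
    ) -> Bool {
        do {
            guard logits.dataType == .float32 else { throw LogitError.unsupportedDataType }
            let ptr = logits.dataPointer.assumingMemoryBound(to: Float.self)
            let values = Array(UnsafeBufferPointer(start: ptr, count: logits.count))
            // log-softmax via the logsumexp trick
            let maxLogit = values.max() ?? 0
            let logSumExp = maxLogit + log(values.reduce(0) { $0 + exp($1 - maxLogit) })
            let logProbs = values.map { $0 - logSumExp }
            let timestampProb = logProbs[timeTokenBegin...].reduce(0) { $0 + exp($1) }
            let maxTextLogProb = logProbs[..<timeTokenBegin].max() ?? -.infinity
            return log(timestampProb) > maxTextLogProb
        } catch {
            return false
        }
    }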

ZachNagengast · Mar 19 '24 17:03