Inefficient token suppression during BeamSearch
There appears to be an issue when performing beam search after running the Whisper model. During the search, there is a check that logits_processors is not empty; if it contains a populated SuppressTokens, the corresponding logits are suppressed by overriding their values with the minimum (through DisableTokens). This suppression involves numerous inefficient and redundant computations over the suppressed indices within the batch, and it appears to block processing of the next batch until the BeamSearch/GreedySearch procedure completes. This becomes a bottleneck whenever specific tokens need to be suppressed.
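To illustrate the pattern, here is a schematic sketch in Python (not the actual CTranslate2 code; the array shapes and token list are made up). It contrasts suppressing indices one by one on every decoding step with applying a precomputed mask, which only does the index bookkeeping once per generation:

```python
# Schematic only -- a stand-in for the logits-processor step, not CTranslate2 code.
import numpy as np

vocab_size = 51865                      # Whisper multilingual vocabulary size
suppress_ids = np.arange(2500)          # stand-in for a large user-supplied list
logits = np.random.randn(32 * 5, vocab_size).astype(np.float32)  # batch 32, beam 5

def suppress_per_index(logits, suppress_ids):
    # Slow pattern: on every decoding step, iterate over every row and every
    # suppressed id -- roughly (batch * beam * len(suppress_ids)) scalar writes.
    for row in logits:
        for tok in suppress_ids:
            row[tok] = np.finfo(logits.dtype).min
    return logits

# Faster pattern: build the boolean mask once, reuse it on every step.
suppress_mask = np.zeros(vocab_size, dtype=bool)
suppress_mask[suppress_ids] = True

def suppress_masked(logits):
    logits[:, suppress_mask] = np.finfo(logits.dtype).min  # one vectorized write
    return logits
```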
During inference through the faster_whisper Python wrapper, I observe the following pattern. I have a pre-defined list of approximately 2500 tokens that the model should never select, and I declare the model with suppression settings for these tokens. When I start inference on a new sample, there is a brief spike in GPU activity while the first batch is processed, followed by an extended computation on a single CPU core (presumably the calculation of the suppressed logits indices), then another GPU spike for the next batch, and the pattern repeats. As a result, inference with token suppression takes roughly six times longer: of every six seconds, about one second is GPU computation and the remaining five are spent in the beam-search procedure with token suppression.
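For reference, this is roughly how the suppression is set up through the wrapper (the model size, audio path, and token list below are placeholders, not my actual data):

```python
from faster_whisper import WhisperModel

model = WhisperModel("large-v2", device="cuda", compute_type="float16")

# Placeholder for the real list of ~2500 token ids that should never be selected.
banned_token_ids = list(range(10000, 12500))

segments, info = model.transcribe(
    "sample.wav",
    beam_size=5,
    suppress_tokens=banned_token_ids,  # triggers the suppression path described above
)
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```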
Any updates here?
I believe I'm running into the same issue as reported here. I see throughput decrease as I increase the batch size for a large file: inference time grows roughly linearly with the batch size when I suppress tokens and use beam size 5.
I pass through the same set of 40 sub-thirty-second audio files, each padded to 30s. I'm running on an A100 8GB and I'm nowhere near maxing out the memory. I expect throughput to increase with batch size, i.e. the total time to transcribe the 40 files to decrease. But what I see is:
| batch size | inference time |
| --- | --- |
| 1 | 16 seconds |
| 2 | 15 seconds |
| 8 | 45 seconds |
| 16 | 78 seconds |
| 32 | 158 seconds |
Reducing the beam size to 2 improves the situation but the trend of decreasing throughput with larger batch size persists.
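For anyone who wants to reproduce this kind of measurement, a minimal sketch of a timing loop using the ctranslate2 Python API might look like the following (not the original benchmark; the model path, features, prompt token ids, and suppression list are placeholders, and real inputs would come from a log-mel front end rather than random data):

```python
# Minimal sketch, not the original benchmark. Assumes a converted Whisper model
# in "whisper-large-v2-ct2" and uses random data in place of real log-mel features.
import time
import numpy as np
import ctranslate2

model = ctranslate2.models.Whisper("whisper-large-v2-ct2", device="cuda")

# 40 clips padded to 30 s -> 80 mel bins x 3000 frames per clip.
features = np.random.randn(40, 80, 3000).astype(np.float32)
# <|startoftranscript|><|en|><|transcribe|><|notimestamps|> for the multilingual tokenizer.
prompt = [50258, 50259, 50359, 50363]
suppress = list(range(10000, 10845))    # placeholder list of 845 suppressed token ids

for batch_size in (1, 2, 8, 16, 32):
    start = time.perf_counter()
    for i in range(0, len(features), batch_size):
        chunk = np.ascontiguousarray(features[i:i + batch_size])
        model.generate(
            ctranslate2.StorageView.from_array(chunk),
            [prompt] * len(chunk),
            beam_size=5,
            suppress_tokens=suppress,
        )
    print(f"batch size {batch_size}: {time.perf_counter() - start:.1f} s")
```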
I think this PR has the potential to solve this issue. The results I posted in the PR description are from a test using a list of 845 suppression tokens.