transformers
[whisper] static kv cache
What does this PR do?
Supersedes https://github.com/huggingface/transformers/pull/28931 and extends it by adding static k/v cache support for Whisper. Also improves the performance of the eager attention implementation by removing unnecessary reshapes (inspired by LlamaAttention).
Similar to #28931, we use a separate cache for the self-attention and cross-attention layers. We define a lightweight `EncoderDecoderCache` wrapper that holds these two cache classes and implements common base methods (e.g. `to_legacy_cache()`) by calling the corresponding methods for each cache class.
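To make the wrapper idea concrete, here is a minimal pure-Python sketch. `ToyCache`, `ToyEncoderDecoderCache`, the attribute names, and the merged legacy layout are hypothetical stand-ins chosen for illustration, not the classes added in this PR; only the delegation pattern is the point.

```python
from dataclasses import dataclass, field


@dataclass
class ToyCache:
    """Hypothetical stand-in for a per-layer k/v cache (not the transformers class)."""

    keys: list = field(default_factory=list)
    values: list = field(default_factory=list)

    def update(self, key, value, layer_idx):
        # Grow the per-layer lists lazily, then store the new states.
        while len(self.keys) <= layer_idx:
            self.keys.append(None)
            self.values.append(None)
        self.keys[layer_idx] = key
        self.values[layer_idx] = value
        return key, value

    def to_legacy_cache(self):
        # Legacy format in this sketch: one (key, value) tuple per layer.
        return tuple(zip(self.keys, self.values))


@dataclass
class ToyEncoderDecoderCache:
    """Holds one cache per attention type and delegates shared methods to both."""

    self_attention_cache: ToyCache
    cross_attention_cache: ToyCache

    def to_legacy_cache(self):
        # Call the same method on each cache, then merge per layer into
        # (self_k, self_v, cross_k, cross_v). The layout is an assumption.
        self_layers = self.self_attention_cache.to_legacy_cache()
        cross_layers = self.cross_attention_cache.to_legacy_cache()
        return tuple(s + c for s, c in zip(self_layers, cross_layers))


cache = ToyEncoderDecoderCache(ToyCache(), ToyCache())
cache.self_attention_cache.update("self_k0", "self_v0", layer_idx=0)
cache.cross_attention_cache.update("cross_k0", "cross_v0", layer_idx=0)
legacy = cache.to_legacy_cache()
```

Keeping the two caches separate means the self-attention cache can grow every step while the cross-attention cache is written once and then only read.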
However, there is one hurdle in enabling compatibility with `torch.compile`. Namely, we have to determine whether we're in the first decoding step, or the second step onwards:
- In the first decoding step, we compute the cross-attention k/v states and update the cache accordingly
- In the second step onwards, we re-use the k/v states directly from the cache. There’s no further update to the cross-attention cache, since the k/v states are derived entirely from the encoder hidden-states (which stay fixed)
The difficulty is in detecting whether we’re in the first decoding step (1), or the second step onwards (2). In eager mode, we can condition on `past_key_values.get_seq_length()` to determine the decoding step. However, under `torch.compile` this introduces a graph break. Consequently, we add a boolean flag `is_updated` to the `StaticCache` class, which tells us whether the cache has been updated or not. The alternative would be to employ the same logic as the Flax code, where we re-compute the cross-attention k/v states each time. Benchmarks show this approach is 1.4x slower than adding the CPU flag.
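The flag-based branching can be sketched in plain Python as follows. `ToyCrossAttentionCache`, `project_kv`, and `cross_attention_kv` are hypothetical names, the per-layer list of flags is an assumption of this sketch, and the arithmetic stands in for the real k/v projections; it only illustrates that the projection runs once and is then served from the cache.

```python
class ToyCrossAttentionCache:
    """Hypothetical cache with a per-layer `is_updated` flag (illustrative only)."""

    def __init__(self, num_layers):
        self.key_value = [None] * num_layers
        self.is_updated = [False] * num_layers


projection_calls = 0


def project_kv(encoder_hidden_states):
    # Stand-in for the cross-attention k/v projections of the encoder output.
    global projection_calls
    projection_calls += 1
    keys = [2 * x for x in encoder_hidden_states]
    values = [x + 1 for x in encoder_hidden_states]
    return keys, values


def cross_attention_kv(encoder_hidden_states, cache, layer_idx):
    if cache.is_updated[layer_idx]:
        # Second step onwards: reuse the cached states, no recomputation.
        return cache.key_value[layer_idx]
    # First decoding step: compute once, store, and flip the flag.
    cache.key_value[layer_idx] = project_kv(encoder_hidden_states)
    cache.is_updated[layer_idx] = True
    return cache.key_value[layer_idx]


cache = ToyCrossAttentionCache(num_layers=1)
encoder_hidden_states = [1.0, 2.0, 3.0]
for _ in range(4):  # four decoding steps
    keys, values = cross_attention_kv(encoder_hidden_states, cache, layer_idx=0)
```

The Flax-style alternative corresponds to calling `project_kv` on every step, which trades the flag for repeated projection work.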
Using the `.generate` API with Whisper medium, we get an approximately 5x speed-up when generating 64 tokens using sdpa attention. Note that we compile the forward pass only:
| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|---|---|---|---|
| 1 | 55.6 | 270.7 | 4.9 |
| 2 | 111.4 | 541.3 | 4.9 |
| 4 | 222.3 | 1078.8 | 4.9 |
| 8 | 446.3 | 2167.4 | 4.9 |
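For reference, the Speed-up column is the ratio of compiled to dynamic throughput, rounded to one decimal place. A short sketch that recomputes it from the Whisper medium figures above (the variable names are illustrative):

```python
# (batch size, dynamic tok/s, compiled tok/s) for Whisper medium, from the table above.
rows = [
    (1, 55.6, 270.7),
    (2, 111.4, 541.3),
    (4, 222.3, 1078.8),
    (8, 446.3, 2167.4),
]

# Speed-up = compiled throughput / dynamic throughput, rounded to one decimal.
speed_ups = {bsz: round(compiled / dynamic, 1) for bsz, dynamic, compiled in rows}
print(speed_ups)  # {1: 4.9, 2: 4.9, 4: 4.9, 8: 4.9}
```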
Extended results:
Whisper large-v3
| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|---|---|---|---|
| 1 | 41.1 | 190.4 | 4.6 |
| 2 | 82.1 | 381.2 | 4.6 |
| 4 | 162.9 | 761.2 | 4.7 |
| 8 | 331.3 | 1522.5 | 4.6 |
Distil-Whisper distil-large-v3
| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|---|---|---|---|
| 1 | 278.7 | 449.1 | 1.6 |
| 2 | 560.5 | 900.3 | 1.6 |
| 4 | 1113.2 | 1798.7 | 1.6 |
| 8 | 2225.0 | 3592.8 | 1.6 |
As expected, the speed-ups for Distil-Whisper are less pronounced:
- With only 2 decoder layers, the decoder forward pass is already >6x faster than Whisper, and we have a very small decoder graph that can be compiled
- The overhead from the logits post-processing now occupies a greater proportion of the generation time. Compiling the logits processors is a good next step for speeding up generation further.
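The second point is an instance of Amdahl's law: if only the forward pass gets faster, the uncompiled remainder bounds the end-to-end gain. A rough sketch follows; the function name and the fractions passed to it are hypothetical values chosen to show the trend, not measurements from this PR.

```python
def end_to_end_speed_up(forward_fraction, forward_speed_up):
    """Overall speed-up when only the forward pass gets faster.

    forward_fraction: share of eager generation time spent in the model forward pass.
    forward_speed_up: factor by which compilation speeds up that share.
    """
    remaining = 1.0 - forward_fraction  # e.g. logits post-processing, left uncompiled
    return 1.0 / (remaining + forward_fraction / forward_speed_up)


# Hypothetical fractions, chosen only to illustrate the trend (not measured).
large_decoder = end_to_end_speed_up(forward_fraction=0.95, forward_speed_up=6.0)
small_decoder = end_to_end_speed_up(forward_fraction=0.50, forward_speed_up=6.0)
```

With the same forward-pass speed-up, the model that spends a smaller share of its time in the forward pass sees a much smaller overall gain, which is the pattern the Distil-Whisper numbers suggest.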
Code example:
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
import logging
import time
torch._logging.set_logs(graph_breaks=True, recompiles=True)
torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
model.to(torch_device, dtype=torch_dtype)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").to(torch_device)
input_features = inputs.input_features.to(torch_dtype)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"
# warm-up: the first calls trigger compilation of the forward pass
for _ in range(2):
    model.generate(input_features)
# inference with the compiled forward pass
pred_ids = model.generate(input_features)
In refactoring the eager attention implementation for the cache abstraction, I managed to remove a lot of wasteful `.view` operations, generally aligning it with LLaMA and giving a performance boost even without compile (TODO: quantify speed-up).
The only regression comes when using FA2 and compile, where we have to introduce a bunch of new `.transpose` operations for compatibility with the shape of our k/v cache (TODO: quantify regression). This is also a known problem in LLaMA.
There are a few tidy-up points left TODO. Once we're happy with the design, I'll complete the PR with the final checklist items:
- [x] Fix failing fast tests
- [x] Tidy docstrings for new arguments (`past_key_values`, `cache_position`)
- [x] Update model doc with FA2 usage
- [x] Run all Whisper slow tests
- [x] Run all ASR pipeline slow tests
- [ ] Check gradients propagate correctly when training with `output_attentions=True`