Add static cache support for Whisper
Feature request
It would be great to have static cache support for Whisper to make it faster with torch.compile. Currently, the generate() function doesn't support cache_implementation="static" for Whisper.
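For clarity, this is the kind of call it would be nice to support (the checkpoint and dummy input below are just illustrative, not from an actual run):

```python
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Dummy log-mel features: (batch, num_mel_bins, num_frames) for whisper-tiny
input_features = torch.randn(1, 80, 3000)

# Desired: generate() builds a static decoder cache, as it already does for
# decoder-only LLMs, so the decoding step can be compiled with fixed shapes
generated_ids = model.generate(input_features, cache_implementation="static")
```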
Motivation
Static cache with torch.compile can make generation much faster.
Your contribution
Static cache is already supported for LLMs, and we see a great speed-up there.
cc @sanchit-gandhi
Let me try. I think I can make it work: we just need to patch https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L313 the same way as the Llama model and pass cache_position, and it should be OK.
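Roughly the same pattern as modeling_llama.py: thread cache_position down to the attention's cache update. A sketch of what I mean (names follow the Llama code; the exact Whisper signature may end up different):

```python
# Inside the decoder attention forward, assuming a Cache-style past_key_value
# and a cache_position tensor passed down from generate()
if past_key_value is not None:
    cache_kwargs = {"cache_position": cache_position}
    key_states, value_states = past_key_value.update(
        key_states, value_states, self.layer_idx, cache_kwargs
    )
```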
@huseinzol05 great, thanks! I think you also need to make sure the model supports initializing the static cache via _setup_cache:
```python
from transformers import StaticCache

# batch_size and max_cache_length come from your generation setup
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
```
I got hit by https://github.com/pytorch/pytorch/issues/123592 at https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L230, but the static cache is already working without torch.compile locally; arange should solve the problem.
Maybe you can use arange instead, like here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L964-L966
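For context, the linked Llama code builds cache_position with torch.arange, which avoids data-dependent ops so fullgraph compilation can trace it. A rough, standalone sketch of that pattern (the numbers are just stand-ins):

```python
import torch

# Stand-ins: 4 tokens already in the cache, 1 new decoder token this step
past_seen_tokens = 4
inputs_embeds = torch.zeros(1, 1, 384)  # (batch, new tokens, hidden size)

cache_position = torch.arange(
    past_seen_tokens,
    past_seen_tokens + inputs_embeds.shape[1],
    device=inputs_embeds.device,
)
```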
Anything dynamic is not possible; feeding position_ids solved the problem, just like cache_position. I will push the initial version later so you can verify; the speed is good.
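To illustrate what I mean by feeding position_ids (just a sketch, not the exact patch): derive them from cache_position instead of anything data-dependent.

```python
import torch

# One new decoder token at step 4 of generation
cache_position = torch.arange(4, 5)

# Broadcast to (batch, seq_len); no dynamic indexing needed
position_ids = cache_position.unsqueeze(0)
```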
Great :+1:! That arange does work well in Llama with fullgraph torch.compile, though.
https://github.com/huggingface/transformers/pull/30760
The compiled static cache is able to achieve 186.26 it/s, while the non-compiled version got 150.20 it/s.
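For anyone who wants to try a similar comparison once this is merged, a rough sketch of the setup (the checkpoint, dtype, and compile mode here are assumptions, not the exact benchmark configuration):

```python
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", torch_dtype=torch.float16
).to("cuda")

# Static decoder cache + compiled forward so decoding steps have fixed shapes
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# Dummy log-mel features: (batch, num_mel_bins, num_frames)
input_features = torch.randn(1, 80, 3000, dtype=torch.float16, device="cuda")
generated_ids = model.generate(input_features)
```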
Closing as this is fixed: #31166 and #31772