transformers

Add static cache support for Whisper

Open mobicham opened this issue 1 year ago • 8 comments

Feature request

It would be great to have static cache support for Whisper so that generation gets faster with torch.compile. Currently, the generate() function does not support cache_implementation="static" for Whisper.

Motivation

Static cache with torch.compile can make generation much faster.
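
To make the motivation concrete, here is a minimal sketch of the idea behind a static cache: the key/value buffers are allocated once at a fixed maximum length and new entries are written in place, so shapes stay constant across decoding steps, which is generally what allows a compiled graph to be reused. This is a toy in plain Python lists rather than tensors; the class name, method name, and parameters are hypothetical and are not the transformers API.

```python
# Toy illustration only: plain Python lists stand in for tensors, and the
# class and method names are hypothetical, not the transformers API.


class ToyStaticCache:
    """Key/value cache whose storage is allocated once, up front."""

    def __init__(self, max_cache_len, head_dim):
        self.max_cache_len = max_cache_len
        # Preallocate every slot so the buffer shape never changes.
        self.keys = [[0.0] * head_dim for _ in range(max_cache_len)]
        self.values = [[0.0] * head_dim for _ in range(max_cache_len)]

    def update(self, cache_position, new_keys, new_values):
        """Write new entries in place at the given absolute positions."""
        for pos, k, v in zip(cache_position, new_keys, new_values):
            if not 0 <= pos < self.max_cache_len:
                raise IndexError(f"position {pos} is outside the cache")
            self.keys[pos] = list(k)
            self.values[pos] = list(v)
        return self.keys, self.values


cache = ToyStaticCache(max_cache_len=8, head_dim=2)
shape_before = (len(cache.keys), len(cache.keys[0]))

# Prefill with three tokens, then decode one more token.
cache.update([0, 1, 2], [[1.0, 1.0]] * 3, [[2.0, 2.0]] * 3)
cache.update([3], [[3.0, 3.0]], [[4.0, 4.0]])

shape_after = (len(cache.keys), len(cache.keys[0]))
```

The buffer has the same shape before and after both updates; only its contents change. A dynamic cache would instead grow by one entry per decoded token.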

Your contribution

Static cache is already supported for LLMs, where we see a great speed-up.

mobicham • May 08 '24 11:05

cc @sanchit-gandhi

amyeroberts • May 08 '24 16:05

Let me try; I think I can do it. It should just need a patch to https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L313, as in the Llama model, plus passing cache_position.

huseinzol05 • May 10 '24 07:05

@huseinzol05 Great, thanks! I think you also need to make sure the model supports initializing the static cache via _setup_cache:

from transformers import StaticCache
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)

mobicham • May 10 '24 07:05

I got hit by https://github.com/pytorch/pytorch/issues/123592 at https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L230, but the static cache is already working locally without torch.compile. Using arange should solve the problem.

huseinzol05 • May 11 '24 07:05

Maybe you can use arange instead like here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L964-L966

mobicham • May 11 '24 09:05

Nothing dynamic is possible. Feeding position_ids as an input solved the problem, just like cache_position. I will push the initial version later so you can verify. The speed is good.
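
A rough sketch of the two approaches discussed here, in plain Python so that it runs without torch. It assumes the problematic lookup sliced a positional embedding table by a runtime past length, and that the fix passes explicit positions in as an input; the function names and the toy table are hypothetical and are not taken from the Whisper code.

```python
# Toy illustration only: a list of rows stands in for the positional
# embedding table, and all function names are hypothetical.


def make_cache_position(past_seen_tokens, seq_len):
    """Plain-Python analogue of arange(past, past + seq_len)."""
    return list(range(past_seen_tokens, past_seen_tokens + seq_len))


def embed_by_slice(table, past_len, seq_len):
    """Look up rows with slice bounds computed from a runtime length."""
    return table[past_len:past_len + seq_len]


def embed_by_position_ids(table, position_ids):
    """Look up rows from positions that are passed in as an input."""
    return [table[i] for i in position_ids]


# A 16-row table with 2 values per row.
table = [[float(i), float(i) * 10] for i in range(16)]

position_ids = make_cache_position(past_seen_tokens=5, seq_len=3)
by_slice = embed_by_slice(table, past_len=5, seq_len=3)
by_ids = embed_by_position_ids(table, position_ids)
```

Both lookups return the same rows. The difference is that the second takes the positions as data, so the lookup itself does not compute slice bounds from a changing length.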

huseinzol05 • May 11 '24 11:05

Great :+1: ! But that arange works well in Llama with fullgraph torch compile.

mobicham • May 11 '24 11:05

https://github.com/huggingface/transformers/pull/30760

The compiled static cache achieves 186.26 it/s, while the non-compiled version gets 150.20 it/s (about a 1.24x speed-up).

huseinzol05 • May 11 '24 14:05

Closing as this is fixed: #31166 and #31772

ArthurZucker • Sep 06 '24 09:09