
forced_decoder_ids in Whisper models significantly impacts performance; use decoder_input_ids instead

Open · tonysimpson opened this issue 2 years ago · 9 comments

Feature request

@ArthurZucker, probably one for you, based on the commit logs.

Using forced_decoder_ids to provide a "prompt" and/or "prefix" to the Whisper model is very inefficient, because a forward pass and a sampling step are performed for each token in forced_decoder_ids even though the result is already known. Instead, the decoder_input_ids model parameter could be used, which needs only one forward pass to initialise the KV cache with all of the input tokens and then immediately starts sampling useful next tokens.

OpenAI's Whisper limits the prompt to half the context length (448 // 2 - 1 = 223 tokens), so if you want transformers' Whisper to behave like OpenAI's Whisper and you expect 20 words + EOS from your input features, the forward-pass counts are:

  • transformers: 244
  • openai-whisper: 21
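To make the forward-pass accounting above concrete, here is a toy sketch in plain Python. ToyDecoder is a hypothetical stand-in for a KV-cached decoder: it never computes logits, it just replays a fixed script of output tokens once the whole prefix is in its cache, and counts how often forward is called. generate_forced mimics the forced_decoder_ids path (one pass per prefix token, predictions discarded) and generate_prefilled mimics passing the whole prefix as decoder_input_ids (one pass fills the cache). None of these names come from transformers.

```python
from typing import List, Tuple

EOS = -100  # hypothetical end-of-sequence id for the toy model


class ToyDecoder:
    """Toy stand-in for a KV-cached decoder that only counts forward passes."""

    def __init__(self, prefix_len: int, script: List[int]):
        self.prefix_len = prefix_len
        self.script = script        # tokens emitted after the prefix, ending in EOS
        self.cache: List[int] = []  # tokens held in the pretend KV cache
        self.passes = 0

    def forward(self, new_tokens: List[int]) -> int:
        """Add new_tokens to the cache in ONE pass and return the next-token prediction."""
        self.passes += 1
        self.cache.extend(new_tokens)
        if len(self.cache) < self.prefix_len:
            return -1  # still inside the prefix: this prediction gets thrown away
        return self.script[len(self.cache) - self.prefix_len]


def generate_forced(prefix: List[int], script: List[int]) -> Tuple[List[int], int]:
    """forced_decoder_ids style: feed the prefix one token per pass."""
    model = ToyDecoder(len(prefix), script)
    tokens = [prefix[0]]
    while tokens[-1] != EOS:
        predicted = model.forward([tokens[-1]])
        pos = len(tokens)
        # Inside the prefix the forced token overrides the model's prediction.
        tokens.append(prefix[pos] if pos < len(prefix) else predicted)
    return tokens, model.passes


def generate_prefilled(prefix: List[int], script: List[int]) -> Tuple[List[int], int]:
    """decoder_input_ids style: a single pass puts the whole prefix in the cache."""
    model = ToyDecoder(len(prefix), script)
    tokens = list(prefix)
    next_token = model.forward(tokens)
    tokens.append(next_token)
    while next_token != EOS:
        next_token = model.forward([next_token])
        tokens.append(next_token)
    return tokens, model.passes


print(generate_forced([1, 2, 3], [7, EOS]))     # ([1, 2, 3, 7, -100], 4)
print(generate_prefilled([1, 2, 3], [7, EOS]))  # ([1, 2, 3, 7, -100], 2)
```

Both paths produce the same tokens. With a prefix of length L (start token included) and G generated tokens including EOS, the forced path runs L - 1 + G passes and the prefilled path runs G. Assuming the issue's own counting (a 224-token prefix, i.e. a start token plus 223 prompt tokens, and 20 words + EOS), that gives 244 and 21.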

I'm raising this as a feature request rather than a bug report or PR because forced_decoder_ids is already well embedded in the code and in the community, so I assume it can't simply be ripped out and that a discussion is needed before a PR.

Here's some code that demonstrates the issue in IPython:

from transformers import (
    WhisperForConditionalGeneration,
    WhisperTokenizerFast,
    WhisperFeatureExtractor,
)
from datasets import load_dataset
import torch
feature_extractor = WhisperFeatureExtractor()
tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny.en", language="english")
# Patch WhisperForConditionalGeneration._prepare_decoder_input_ids_for_generation, because the version on GenerationMixin doesn't handle Whisper properly.
def prepare_decoder_input_ids_for_generation_patch(self, batch_size, model_input_name, model_kwargs, decoder_start_token_id, bos_token_id, device):
    if 'decoder_input_ids' not in model_kwargs:
        # No prefix supplied: start every batch item from decoder_start_token_id alone
        return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id, model_kwargs
    else:
        # Use the caller's decoder_input_ids verbatim, without prepending <|startoftranscript|>
        return model_kwargs.pop('decoder_input_ids'), model_kwargs
WhisperForConditionalGeneration._prepare_decoder_input_ids_for_generation = prepare_decoder_input_ids_for_generation_patch
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
audio = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[3]["audio"]["array"]
input_features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_features

# A custom logits processor to show how many times the forward pass and sample are run
def logits_processor_count_factory():
    count = 0
    def logits_processor_count(input_ids, scores):
        nonlocal count
        count += 1
        print(count)
        return scores
    return logits_processor_count

PREV_TOKEN = 50360 # <|startofprev|>
prompt_tokens = [PREV_TOKEN, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284, 7062, 465, 21443, 13, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236, 878, 514, 11, 985, 2915, 7428, 422, 6600, 290, 663, 2482, 3051, 749, 14704, 284, 262, 2000, 13]
# note prompt_ids is prefixed to forced_decoder_ids inside generate
# counts to 106
forced_decoder_ids_output = model.generate(input_features=input_features, return_timestamps=False, prompt_ids=torch.LongTensor(prompt_tokens), logits_processor=[logits_processor_count_factory()])[0]
print(tokenizer.decode(forced_decoder_ids_output, decode_with_timestamps=False))

SOT_TOKEN = 50257 # <|startoftranscript|>
NO_TIMESTAMPS_TOKEN = 50362 # <|notimestamps|>
decoder_input_ids = torch.LongTensor([prompt_tokens + [SOT_TOKEN, NO_TIMESTAMPS_TOKEN]])
# counts to 31
decoder_input_ids_output = model.generate(input_features=input_features, return_timestamps=False, forced_decoder_ids=None, begin_suppress_tokens=None, decoder_input_ids=decoder_input_ids, logits_processor=[logits_processor_count_factory()])[0]
print(tokenizer.decode(decoder_input_ids_output, decode_with_timestamps=False))

You can measure the performance of both in IPython with:

%timeit model.generate(input_features=input_features, return_timestamps=False, prompt_ids=torch.LongTensor(prompt_tokens))[0]
%timeit model.generate(input_features=input_features, return_timestamps=False, forced_decoder_ids=None, begin_suppress_tokens=None, decoder_input_ids=decoder_input_ids)[0]

On my CPU, using decoder_input_ids is 2x faster with this input.

Motivation

I want to be able to use the transformers implementation of Whisper in a production system where cost and processing time are critical. Because of the way we use Whisper, this issue impacts performance much more than the 2x I quoted above; it's more like 5x in our use case. We can obviously code around it, but if it's possible to change transformers and avoid custom code, I'd prefer that.

Your contribution

I'd be able to create a PR, but without knowing more about how the maintainers would like to handle backward compatibility etc., I don't think that's the right place to start.

I'd be very happy to be involved in a discussion, offer opinions or testing etc.

tonysimpson, May 29 '23

Hey! Thanks for taking the time to open this issue. I totally get the speedup, and the latency induced by using forced_decoder_ids rather than decoder_input_ids. The addition of prompt_ids was mostly handled by @hollance, who will be able to take a better look at this. I don't think there has been a release yet, which means this can still be changed (if it's not impossible to update).

ArthurZucker, May 30 '23

IIRC, we decided for the time being to keep using forced_decoder_ids for the prompts, even though it is indeed slower. It would be nice to improve this.

hollance, May 30 '23

What might a path to improvement look like? Perhaps a PR to make sure that passing in custom decoder_input_ids works correctly would be a good start; I'm happy to do that. I know it doesn't work for PT, because GenerationMixin can add the <|startoftranscript|> token in the wrong place. I haven't tried TF or Flax.

tonysimpson, May 31 '23

I don't understand this part of the generation process well enough yet to say anything useful about it. You'd think that we could start generation by passing in the entire forced_decoder_ids as the decoder_input_ids as the first step, rather than doing it one token at a time. The ForceTokensLogitsProcessor also plays a part in this.
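A minimal sketch of that idea, assuming forced_decoder_ids uses its usual list of (position, token_id) pairs and that position 0 is occupied by decoder_start_token_id. The helper name forced_ids_to_prefix is hypothetical, not a transformers function. Only an unbroken run of positions 1, 2, 3, ... can move into decoder_input_ids; anything forced after a gap is returned unchanged and would still need ForceTokensLogitsProcessor.

```python
from typing import Dict, List, Sequence, Tuple


def forced_ids_to_prefix(
    decoder_start_token_id: int,
    forced_decoder_ids: Sequence[Sequence[int]],
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Hypothetical helper: split forced_decoder_ids into a decoder_input_ids prefix.

    Returns (prefix, leftover): prefix can be fed in one forward pass,
    leftover holds forced (position, token_id) pairs that come after a gap.
    """
    forced: Dict[int, int] = dict(forced_decoder_ids)
    prefix = [decoder_start_token_id]
    pos = 1
    # Consume consecutive forced positions starting right after the start token.
    while pos in forced:
        prefix.append(forced.pop(pos))
        pos += 1
    return prefix, sorted(forced.items())


# Illustrative ids only; the exact values depend on the checkpoint's vocabulary.
prefix, leftover = forced_ids_to_prefix(50258, [[1, 50259], [2, 50359], [3, 50363]])
print(prefix, leftover)  # [50258, 50259, 50359, 50363] []
```

The design choice here is conservative: rather than assuming every forced position is contiguous, it only prefills what is safe and hands the rest back, so existing callers that force a token later in the sequence would keep their behaviour.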

@Narsil can probably enlighten us 😄

hollance, May 31 '23

@hollance Yes, we could absolutely convert forced_decoder_ids to decoder_input_ids in .generate(...), and I think we can do it in a way that doesn't break anyone's code. I can probably put together a draft PR for the PT code sometime tomorrow.

tonysimpson, May 31 '23

Hi, not sure if I can enlighten.

In general, I'm not sure what forced_decoder_ids is useful for: if you already know which ids you should get, there's no need to run inference.

Since it was added, the usual assumption is that it must have been useful for some reason at some point, but I don't really understand its purpose in this specific use case.

Narsil, Jun 05 '23

@Narsil For Whisper, we want to start generation not with a single "BOS" token (here, <|startoftranscript|>) but with several tokens. In the case of prompting, this could be a fairly long sequence of tokens. For example <|startofprev|> here is the prompt <|startoftranscript|><|en|><|notimestamps|>. The prompt text is used to prime the model with more context. Right now, we use forced_decoder_ids to feed in this sequence of "starting tokens", which means they get processed one-by-one in the generation loop. It's more efficient to allow the first step of generation to process this entire sequence at once.

hollance, Jun 06 '23

Yes, I know. I don't think it's necessary, but I usually give the benefit of the doubt when something was coded intentionally.

Narsil, Jun 06 '23

Hello everyone, what if we simply pass decoder_input_ids as an argument to the generate call?

  generated_ids = self.model.generate(
      inputs=input_features,
      decoder_input_ids=torch.tensor(
          [decoder_ids], dtype=torch.long
      ),
  ).cpu()

As I understand it, it will be used here.

DavraYoung, Jul 26 '23

Hi, I'm trying to run the ONNX model. When I export it using optimum-cli, I get 4 ONNX models: decoder_model, decoder_model_merged, decoder_with_past_model and encoder_model.

Can anyone please help me run inference with these 4 models? The encoder model gives 1 output, encoder_hidden_state, with shape (1, 1500, 384), whereas the plain decoder_model takes 2 inputs: encoder_hidden_state and decoder_input_ids. I've tried multiple decoder_ids, but I'm still not getting the correct output.

Can anyone please suggest the correct decoder_input_ids to give the model? Thanks in advance.
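A rough sketch of a greedy loop for the non-cached decoder_model, under stated assumptions. decoder_fn is a hypothetical callable that you would write yourself: it should run the exported decoder on the encoder output plus the current decoder_input_ids and return the logits for the last position only. The start ids are the tiny.en values used earlier in this thread (<|startoftranscript|>, <|notimestamps|>); multilingual checkpoints use different ids and usually also need language and task tokens. EOT_TOKEN = 50256 (<|endoftext|>) is assumed to be the end-of-transcript id for the .en vocabulary.

```python
from typing import Callable, List, Sequence

# tiny.en ids from this thread; multilingual checkpoints differ.
SOT_TOKEN = 50257            # <|startoftranscript|>
NO_TIMESTAMPS_TOKEN = 50362  # <|notimestamps|>
EOT_TOKEN = 50256            # <|endoftext|> (assumed end-of-transcript id for .en)


def greedy_decode(
    decoder_fn: Callable[[List[int]], Sequence[float]],
    start_ids: List[int],
    eos_id: int,
    max_len: int = 448,
) -> List[int]:
    """Re-run the decoder on the growing decoder_input_ids until EOS or max_len.

    decoder_fn is hypothetical: it should wrap the exported decoder together with
    the encoder output and return logits for the LAST position only.
    """
    ids = list(start_ids)
    while len(ids) < max_len:
        logits = decoder_fn(ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax over vocab
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids
```

With a real model this would be called as greedy_decode(my_decoder_fn, [SOT_TOKEN, NO_TIMESTAMPS_TOKEN], EOT_TOKEN), and the result decoded with the tokenizer. Because the plain decoder_model has no cache, every step reprocesses the whole sequence; the with-past variant presumably exists to avoid that cost, which is the same trade-off this issue is about.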

SahinMjks, Nov 21 '23

Resolved in https://github.com/huggingface/transformers/pull/28687.

sanchit-gandhi, Mar 28 '24