
[WIP] - Support generating with fallback for short form audio in Whisper

Open · kamilakesbi opened this pull request 9 months ago · 1 comment

What does this PR do?

The aim of this PR is to refactor the Whisper generate method so that short form and long form audio generation are handled by the same code path. It will support short form audio generation with fallback (as requested in #29508).
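For context, "fallback" here refers to Whisper's temperature fallback decoding: a segment is re-generated at increasing temperatures whenever the output looks degenerate. Below is a minimal sketch of the idea; the helper names and output fields are assumptions, and the quality thresholds follow the original Whisper paper rather than the exact transformers internals:

# Minimal sketch of temperature fallback (assumed names, illustrative only).
def generate_with_fallback(generate_segment, segment_input,
                           temperatures=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)):
    for temperature in temperatures:
        output = generate_segment(segment_input, temperature=temperature)
        # Accept the segment unless it looks degenerate: too repetitive
        # (gzip compression ratio > 2.4) or too unlikely (avg logprob < -1.0).
        if output.compression_ratio <= 2.4 and output.avg_logprob >= -1.0:
            return output
    # Every temperature failed the quality checks; keep the last attempt.
    return output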

I've been working on a first draft of what this could look like. Here's what I've done so far:

  • Removed the code path used only for short form generation. Now, when a short form audio (or a batch of short form audios) is passed to generate, it goes through the code path previously reserved for long form generation.

  • I still use an is_shortform parameter to distinguish between short form and long form audios (see the sketch after this list). I've adapted the parts of the code that rely on this parameter:

--> _retrieve_max_frames_and_seek needs to be adapted: when processing batched short form audios, we don't necessarily need the attention_masks.

--> In short form generation, each sequence starts with the decoder_input_ids and ends with the eos token. I've made sure this is still the case with the new generate method.

--> I made sure we can still do short form generation when generation_config.no_timestamps_token_id is not defined.

--> I made sure we can still do short form generation when logits_processor is None.

  • I've also adapted the code to make it compatible with return_token_timestamps=True.
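For reference, here is a rough sketch of two of the adaptations mentioned above: detecting short form inputs (Whisper's encoder consumes a single fixed 30-second window of 3000 log-mel frames, so anything at or below that length fits in one window), and retrieving max_frames and seek without attention masks. This is assumed illustrative logic, not the exact code in this PR:

import torch

def is_shortform_input(input_features: torch.Tensor, num_segment_frames: int = 3000) -> bool:
    # At most one 30 s window (3000 mel frames) means short form; longer
    # inputs must be chunked and decoded with seek/fallback logic.
    return input_features.shape[-1] <= num_segment_frames

def retrieve_max_frames_and_seek(batch_size: int, attention_mask, total_input_frames: int):
    # Batched short form inputs are all padded to a single window, so no
    # attention mask is needed: every item spans the full input length.
    if attention_mask is None:
        max_frames = torch.full((batch_size,), total_input_frames, dtype=torch.long)
    else:
        # For long form batches, each item's usable length is its number
        # of unmasked feature frames.
        max_frames = attention_mask.sum(-1)
    seek = torch.zeros((batch_size,), dtype=torch.long)  # decoding starts at frame 0
    return max_frames, seek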

I ran the following snippet and compared the outputs of the old and new generate methods:

from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", torch_dtype=torch.float16)
model = model.to("cuda")

# Batched short form audios: two short clips from the dummy LibriSpeech set.
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:2]")
audios = [x["array"] for x in dataset["audio"]]
inputs = processor(audios, sampling_rate=16_000, return_tensors="pt", truncation=False).to("cuda", torch.float16)

result = model.generate(**inputs, return_timestamps=False)

return_timestamps=False, return_timestamps=True, and return_token_timestamps=True all give the same outputs with the old and new generate methods.
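For completeness, the timestamp variants of the comparison look like this (continuing the snippet above; token-level timestamps rely on the alignment_heads defined in the model's generation config):

# Same inputs, with segment-level and then token-level timestamps enabled:
result_ts = model.generate(**inputs, return_timestamps=True)
result_token_ts = model.generate(**inputs, return_timestamps=True, return_token_timestamps=True)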

Next steps:

  • We currently get errors when num_return_sequences > 1; this still needs to be handled (see the example below).
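As an illustration, a call along these lines currently fails under the refactor (a hypothetical repro; num_beams is set because returning multiple sequences requires beam search or sampling):

# Hypothetical failing case: multiple return sequences per input.
result = model.generate(**inputs, num_beams=2, num_return_sequences=2)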

Who can review:

@sanchit-gandhi

kamilakesbi · May 23, 2024