[WIP] - Support generating with fallback for short form audio in Whisper
What does this PR do?
The aim of this PR is to refactor the Whisper generate method to handle both short form and long form audio generation in the same way. It will support short form audio generation with fallback (as requested in #29508).
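(For context, "fallback" here means the temperature fallback already used for long form generation: if a generated segment looks degenerate, it is regenerated with a higher temperature. Below is a rough, self-contained sketch of the idea; the helper name and the toy usage are illustrative, not the actual transformers logic.)
# Illustrative temperature fallback loop; the real criteria in Whisper also
# look at the average log-probability and compression ratio of the decoded text.
def generate_with_fallback(generate_fn, temperatures=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)):
    for temperature in temperatures:
        tokens, needs_fallback = generate_fn(temperature)
        if not needs_fallback:
            return tokens, temperature
    # Every temperature failed the quality checks: keep the last attempt.
    return tokens, temperatures[-1]
# Toy usage: pretend the first two attempts are flagged as degenerate.
attempts = iter([(["a", "a", "a"], True), (["a", "a", "a"], True), (["hello", "world"], False)])
print(generate_with_fallback(lambda t: next(attempts)))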
I've been working on a first draft of what it would look like. Here's what I've done so far:
- Removed the part of the code used for short form generation. Now, when a short form audio (or a batch of short form audios) is passed to generate, it is processed by the part of the code previously used for long form generation.
- I still use an is_shortform parameter to distinguish between short form and long form audios. I've adapted the parts of the code where this parameter is needed:
--> _retrieve_max_frames_and_seek needs to be adapted: if we are processing batched short form audios, we don't necessarily need the attention masks (a rough sketch of this point follows the list).
--> In short form generation, the start and end of each sequence contain the decoder_input_ids and eos tokens. I've made sure this is still the case with the new generate method.
--> I made sure we can still do short form generation when generation_config.no_timestamps_token_id is not defined.
--> I made sure we can still do short form generation when logits_processor is None.
- I've also adapted the code to make it compatible with return_token_timestamps=True.
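To make the _retrieve_max_frames_and_seek point concrete, here is a rough, self-contained sketch of how per-sample frame limits could be derived with and without attention masks (the function is a standalone illustration, not the actual transformers implementation):
import torch
# Hypothetical sketch: batched short form inputs can get their frame limits
# without attention masks, while long form inputs still rely on them.
def retrieve_max_frames_and_seek(input_features, attention_mask=None, frames_per_window=3000):
    batch_size, _, total_frames = input_features.shape
    is_shortform = total_frames <= frames_per_window
    if is_shortform or attention_mask is None:
        # Short form batch: every sample spans at most one 30s window, so the
        # padded length is a safe upper bound and no mask is required.
        max_frames = torch.full((batch_size,), total_frames, dtype=torch.long)
    else:
        # Long form batch: the attention mask marks where each audio ends.
        max_frames = attention_mask.sum(-1).long()
    seek = torch.zeros(batch_size, dtype=torch.long)
    return is_shortform, max_frames, seek
# Two 30s-padded mel spectrograms (80 mel bins, 3000 frames) -> short form.
features = torch.zeros(2, 80, 3000)
print(retrieve_max_frames_and_seek(features))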
I ran the following snippet and compared the outputs of the old and new generate methods:
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", torch_dtype=torch.float16)
model = model.to("cuda")
# Batched short form audios:
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:2]")
ds = dataset.select(range(2))[:2]['audio']
audios = [x["array"] for x in ds]
inputs = processor(audios, return_tensors="pt", truncation=False).to("cuda", torch.float16)
result = model.generate(**inputs, return_timestamps=False)
print(processor.batch_decode(result, skip_special_tokens=True))
With return_timestamps=False, return_timestamps=True, and return_token_timestamps=True, the old and new generate methods give the same outputs.
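For completeness, the timestamp variants I compare look roughly like this, reusing model, processor, and inputs from the snippet above (the exact structure of the token timestamps output may still change while this PR is a WIP):
result_ts = model.generate(**inputs, return_timestamps=True)
result_token_ts = model.generate(**inputs, return_timestamps=True, return_token_timestamps=True)
print(processor.batch_decode(result_ts, skip_special_tokens=True))
# return_token_timestamps=True additionally returns per-token timestamps
# alongside the generated sequences (e.g. under "token_timestamps").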
Next steps:
- We will get errors if num_return_sequences > 1 (minimal repro below).
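A minimal repro for that failure, reusing the inputs from the snippet above (num_beams=2 is only there because num_return_sequences > 1 needs beam search or sampling):
# Currently raises an error with the refactored generate for short form inputs.
result = model.generate(**inputs, num_beams=2, num_return_sequences=2, return_timestamps=False)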
Who can review:
@sanchit-gandhi