Whisper Prompting
Feature request
Add prompting for the Whisper model to control the style/formatting of the generated text.
Motivation
During training, Whisper can be fed a "previous context window" to condition on longer passages of text.
The original OpenAI Whisper implementation provides the user with the option of passing an `initial_prompt` to the model. This prompt replaces the "previous context window" during inference.
By passing the prompt as the "previous context window", the Whisper model conditions its generation on whatever text is passed as the prompt. This allows the user to control aspects of the generation, such as spellings of named entities and punctuation formatting (see https://github.com/openai/whisper/discussions/963#discussioncomment-4987057).
This is possibly a cheaper way of adapting the Whisper model to specific decoding constraints than fine-tuning.
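For reference, here is a minimal sketch of how this works with the original `openai-whisper` package (the audio path, checkpoint size and prompt text are just placeholders):

```python
import whisper

# Load a checkpoint from the original OpenAI implementation
model = whisper.load_model("base")

# Without a prompt, rare named entities are often transcribed with the wrong spelling
result = model.transcribe("audio.mp3")
print(result["text"])

# initial_prompt is injected as the "previous context window", so the decoder
# conditions on it and tends to copy its spelling/formatting conventions
result = model.transcribe("audio.mp3", initial_prompt="IR, Newswire")
print(result["text"])
```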
This notebook demonstrates prompting with the initial codebase, and explains how this can be achieved for HF's Whisper: https://colab.research.google.com/drive/14FSeaoRvgs5arOTfiMQBnQ5NaLyma7Tq?usp=sharing
The proposed API for prompting would look something like the following (a full end-to-end sketch is given after the list):
- Encode the prompt text to prompt token ids (`processor.get_prompt_ids`) - this method is a wrapper around `processor.tokenizer.__call__` that doesn't add the special token ids:

```python
prompt = "IR, Newswire"
prompt_ids = processor.get_prompt_ids(prompt)
```

- Pass the input audio and prompt token ids to the `.generate` method to get the predicted ids:

```python
pred_ids = model.generate(input_features, prompt_ids=prompt_ids)
```

- Decode the predicted ids and 'slice' off the prompt (we can do this by passing the `prompt_ids`):

```python
pred_str = processor.batch_decode(pred_ids, prompt_ids=prompt_ids)
```
=> We would need to wrap all of this `forced_decoder_ids` logic into the `generate` method and update the processor/tokenizer accordingly.
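Putting the pieces together, an end-to-end sketch of the proposed flow could look like the following (the `openai/whisper-tiny` checkpoint, the dummy dataset and the exact `batch_decode` signature are illustrative only; the final behaviour would be settled in the PR):

```python
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Illustrative audio sample
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features

# 1. Encode the prompt text (no special tokens added)
prompt_ids = processor.get_prompt_ids("IR, Newswire")

# 2. Generate, conditioning the decoder on the prompt
pred_ids = model.generate(input_features, prompt_ids=prompt_ids)

# 3. Decode and slice off the prompt tokens
pred_str = processor.batch_decode(pred_ids, prompt_ids=prompt_ids)
print(pred_str)
```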
Your contribution
Happy to guide the integration and review any PRs!
cc @hollance
Hello, I'd like to pick up this issue!
Hey @mollerup23! Super cool! We would first need to update the generate modelling code to slide the forced decoder ids as explained in the notebook:
https://github.com/huggingface/transformers/blob/d5de578c2227250d615f73a8fb88a5ce7f1743be/src/transformers/models/whisper/modeling_whisper.py#L1453
And then add a new method in the tokenizer to ignore the prompt ids. Does this sound good to you?
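For illustration only (the token layout follows the Whisper paper/notebook; the helper name and slicing are assumptions, not the final implementation), the idea is roughly:

```python
# During prompted generation the decoder context is laid out as:
#   <|startofprev|> {prompt tokens} <|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> {generated tokens}
# (language/task tokens depend on the generation settings)
#
# so the forced decoder ids need to be shifted to come after the prompt, and the
# tokenizer needs a way to drop the prompt again before decoding.

def decode_ignoring_prompt(tokenizer, token_ids, prompt_ids):
    """Illustrative helper: drop <|startofprev|> plus the prompt tokens from the
    front of the generated sequence, then decode the rest."""
    token_ids = list(token_ids)
    prev_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
    if token_ids and token_ids[0] == prev_token_id:
        token_ids = token_ids[1 + len(prompt_ids):]
    return tokenizer.decode(token_ids, skip_special_tokens=True)
```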
Hey @mollerup23 @sanchit-gandhi. Apologies, I'm not sure how picking these up works. I started working on it because I saw there was no assignee, and I now have something I think is ready for review. Should I just keep it locally or push it up?
Totally fine with whatever, @mollerup23 commented first.
@connor-henderson @sanchit-gandhi I have not yet started on this issue, feel free to push your commits and pick it up!
I will continue to look into what @sanchit-gandhi mentioned in the meantime.
Sounds good, thanks
Closed via https://github.com/huggingface/transformers/pull/22496
Hi @sanchit-gandhi and @connor-henderson
I saw the PR, but I was wondering if we also integrated `always_use_initial_prompt` and `condition_on_previous_text` into the API? If not, is there any active work going towards it?
Thanks
Hey @romitjain - we're working on integrating the OpenAI Whisper algorithm into Transformers, which will provide more support for these fine-grained decoding parameters! c.f. #27492
Are contributions allowed here? I'd like to help with that.