Add Flax Whisper for audio classification
Feature request
The PR https://github.com/huggingface/transformers/pull/21754 adds the PyTorch version of WhisperForAudioClassification. It would be great to add the Flax equivalent for cross-library equivalence ♻️
Motivation
Whisper is an encoder-decoder model for speech recognition. However, we can repurpose the model for other speech tasks, such as audio classification.
Audio classification is the task of mapping from an input speech sequence to a single class prediction. For more details, refer to the task page on the Hub: https://huggingface.co/tasks/audio-classification
For audio classification, we only require a single model output. Thus, we do not need the auto-regressive generation capabilities of the Whisper decoder (which is used to generate a sequence of text tokens during speech recognition). Instead, we can just use the Whisper encoder to get hidden states, and add a classification head on top to make class label predictions.
This is analogous to using a Wav2Vec2 model for audio classification: the Wav2Vec2 encoder is used to get hidden states, and a classification head is added on top to make class label predictions.
The PR https://github.com/huggingface/transformers/pull/21754 adds the PyTorch version of WhisperForAudioClassification. It required adding a projection layer and classification layer on top of the WhisperEncoder. For more details, refer directly to the pull request.
It would be great to add the Flax equivalent of this model for cross-framework support.
The most difficult part of this PR will be getting the model tester to work. You can see from the PyTorch PR that we require a standalone tester for the audio classification model. This is because the original Whisper model is an encoder-decoder model, but the audio classification model is an encoder-only model. Thus, we require different testing logic.
Your contribution
Opening this one up to the community! This will be quite a fun JAX/Flax PR! 🚀
If you're interested in tackling this, feel free to drop a comment in this thread and open a PR when you're ready. More than happy to answer any questions about this integration!
Hi. It's my first time contributing to open source. I want to tackle this issue. How can I get started?
I have contributed to a few good first issues on HF, and would like to take this one to learn JAX if it's still available!
@Potato-Cracker , @yhl48 Are you guys currently working on it? I have a working branch locally with passing tests, but if you guys would like to make PR, that's totally cool too.
@Shubhamai please go ahead with the PR!
Uh, looks like a PR has already been submitted :smile: , I will see if I can assist with the linked PR.
Very cool that there's so much interest in adding Flax models! Great to see that the JAX/Flax community is so active 🙌 Would you guys be interested in finding other PyTorch models to port to JAX/Flax in transformers?
@sanchit-gandhi I will be happy to contribute to some of them, would be great if you have any suggestions on any particular models!
Very cool! You can take a look at the model integration table here: https://github.com/huggingface/transformers/blob/main/docs/source/en/index.mdx#supported-frameworks
There are a bunch of popular models that are supported in PyTorch but not Flax, LLaMa being one of them! This could be a cool model addition if you're interested?
I would love to take up LLaMa if it's available.
Very cool! What I would suggest is starting from the Flax GPT-Neo model (since this is the Flax model most similar to LLaMa) and then adding the new bits in.
@sanchit-gandhi Would love to take on https://huggingface.co/openai-gpt. I just hope inference on my Mac works out
@sanchit-gandhi .hello, I would like to work on TAPAS.