
Add Flax Whisper for audio classification

Open sanchit-gandhi opened this issue 1 year ago • 9 comments

Feature request

The PR https://github.com/huggingface/transformers/pull/21754 adds the PyTorch version of WhisperForAudioClassification. It would be great to add the Flax equivalent for cross-library equivalence ♻️

Motivation

Whisper is an encoder-decoder model for speech recognition. However, we can repurpose the model for other speech tasks, such as audio classification.

Audio classification is the task of mapping from an input speech sequence to a single class prediction. For more details, refer to the task page on the Hub: https://huggingface.co/tasks/audio-classification

For audio classification, we only require a single model output. Thus, we do not need the auto-regressive generation capabilities of the Whisper decoder (which are used to generate a sequence of text tokens during speech recognition). Instead, we can use the Whisper encoder alone to get hidden states, and add a classification head on top to make class label predictions.

This is analogous to using a Wav2Vec2 model for audio classification: the Wav2Vec2 encoder is used to get hidden states, and a classification head is added on top to make class label predictions.

The PR https://github.com/huggingface/transformers/pull/21754 adds the PyTorch version of WhisperForAudioClassification. It required adding a projection layer and a classification layer on top of the WhisperEncoder. For more details, refer directly to the pull request.
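For a sense of the architecture, here is a minimal Flax sketch of such a classification head. The config attribute names (`classifier_proj_size`, `num_labels`) and the mean-pooling are assumptions that mirror the PyTorch PR, so check the PR for the actual layer names and pooling logic:

```python
import flax.linen as nn
import jax.numpy as jnp


class FlaxWhisperClassificationHead(nn.Module):
    """Hypothetical sketch: project encoder hidden states, mean-pool over
    time, and classify. Not the actual transformers implementation."""

    classifier_proj_size: int
    num_labels: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, encoder_hidden_states):
        # Project (batch, frames, d_model) down to the classifier dimension.
        hidden_states = nn.Dense(self.classifier_proj_size, dtype=self.dtype)(
            encoder_hidden_states
        )
        # Mean-pool over the time axis to get one vector per utterance.
        pooled_output = jnp.mean(hidden_states, axis=1)
        # Map the pooled representation to class logits.
        return nn.Dense(self.num_labels, dtype=self.dtype)(pooled_output)
```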

It would be great to add the Flax equivalent of this model for cross-framework support.

The most difficult part of this PR will be getting the model tester to work. You can see from the PyTorch PR that we require a standalone tester for the audio classification model. This is because the original Whisper model is an encoder-decoder model, but the audio classification model is an encoder-only model. Thus, we require different testing logic.
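To illustrate the difference, a standalone tester for the encoder-only model might prepare inputs along these lines. The class name, shapes, and helper here are illustrative assumptions, not the actual test-suite code:

```python
import numpy as np


class FlaxWhisperEncoderModelTester:
    """Hypothetical sketch: an encoder-only tester prepares input features
    and class labels, whereas the seq2seq Whisper tester would also
    prepare decoder_input_ids."""

    def __init__(self, batch_size=2, num_mel_bins=80, seq_length=60, num_labels=3):
        self.batch_size = batch_size
        self.num_mel_bins = num_mel_bins
        self.seq_length = seq_length
        self.num_labels = num_labels

    def prepare_config_and_inputs(self):
        # Log-mel spectrogram features: (batch, num_mel_bins, frames).
        input_features = np.random.randn(
            self.batch_size, self.num_mel_bins, self.seq_length
        ).astype(np.float32)
        # One integer class label per example.
        labels = np.random.randint(0, self.num_labels, size=(self.batch_size,))
        return {"input_features": input_features, "labels": labels}
```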

Your contribution

Opening this one up to the community! This will be quite a fun JAX/Flax PR! 🚀

If you're interested in tackling this, feel free to drop a comment in this thread and open a PR when you're ready. More than happy to answer any questions about this integration!

sanchit-gandhi avatar Feb 24 '23 08:02 sanchit-gandhi

Hi. It's my first time contributing to open source. I want to tackle this issue. How can I get started?

MBora avatar Feb 25 '23 19:02 MBora

I have contributed to a few good first issues on HF and would like to take this one to learn JAX, if it's still available!

yhl48 avatar Feb 28 '23 00:02 yhl48

@Potato-Cracker, @yhl48 are you currently working on this? I have a working branch locally with passing tests, but if you would like to make the PR, that's totally cool too.

Shubhamai avatar Mar 02 '23 06:03 Shubhamai

@Shubhamai please go ahead with the PR!

yhl48 avatar Mar 02 '23 10:03 yhl48

Uh, looks like a PR has already been submitted :smile: I'll see if I can assist with the linked PR.

Shubhamai avatar Mar 03 '23 14:03 Shubhamai

Very cool that there's so much interest in adding Flax models! Great to see that the JAX/Flax community is so active 🙌 Would you guys be interested in finding other PyTorch models to port to JAX/Flax in transformers?

sanchit-gandhi avatar Mar 17 '23 15:03 sanchit-gandhi

@sanchit-gandhi I will be happy to contribute to some of them; it would be great if you have suggestions on particular models!

yhl48 avatar Mar 17 '23 16:03 yhl48

Very cool! You can take a look at the model integration table here: https://github.com/huggingface/transformers/blob/main/docs/source/en/index.mdx#supported-frameworks

There are a bunch of popular models that are supported in PyTorch but not in Flax, LLaMa being one of them! This could be a cool model addition if you're interested!

sanchit-gandhi avatar Mar 24 '23 13:03 sanchit-gandhi

I would love to take up LLaMa if it's available.

Shubhamai avatar Mar 24 '23 17:03 Shubhamai

Very cool! What I would suggest is starting from the Flax GPT-Neo model (since this is the Flax model most similar to LLaMa) and then adding the new bits in.

sanchit-gandhi avatar Apr 04 '23 15:04 sanchit-gandhi

@sanchit-gandhi Would love to take on https://huggingface.co/openai-gpt. I just hope inference on my Mac works out!

mayankagarwals avatar Apr 06 '23 16:04 mayankagarwals

@sanchit-gandhi Hello, I would like to work on TAPAS.

elabongaatuo avatar Apr 18 '23 12:04 elabongaatuo