[Flax Examples] Seq2Seq ASR Fine-Tuning Script
What does this PR do?
Can be used to fine-tune Flax Whisper for speech recognition.
Tested and verified as working with the following (dummy) config:
```
run_flax_speech_recognition_seq2seq.py \
  --model_name_or_path openai/whisper-tiny.en \
  --dataset_name hf-internal-testing/librispeech_asr_dummy \
  --dataset_config clean \
  --train_split_name validation \
  --eval_split_name validation \
  --output_dir whisper-tiny-ft-dummy \
  --overwrite_output_dir \
  --num_train_epochs=2 \
  --max_train_samples 10 \
  --max_eval_samples 10 \
  --warmup_steps=8 \
  --do_train \
  --do_eval \
  --learning_rate=2e-4 \
  --per_device_train_batch_size=2 \
  --per_device_eval_batch_size=1 \
  --predict_with_generate
```
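For orientation, the core of such a script is a jitted train step that computes a cross-entropy loss and applies a gradient update. A minimal sketch of that pattern, using a toy linear head as a hypothetical stand-in for the Flax Whisper model (not the real `transformers` API):

```python
import jax
import jax.numpy as jnp

def init_params(key, vocab_size=16, hidden=8):
    # Toy "decoder head": maps hidden states to vocabulary logits.
    return {"w": jax.random.normal(key, (hidden, vocab_size)) * 0.1}

def loss_fn(params, hidden_states, labels):
    logits = hidden_states @ params["w"]              # (batch, seq, vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Cross-entropy against the integer target token ids.
    nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1).squeeze(-1)
    return nll.mean()

@jax.jit
def train_step(params, hidden_states, labels, lr=0.1):
    loss, grads = jax.value_and_grad(loss_fn)(params, hidden_states, labels)
    # Plain SGD update; the real script would use an optax optimizer with warmup.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

key = jax.random.PRNGKey(0)
params = init_params(key)
hidden_states = jax.random.normal(key, (2, 4, 8))  # dummy decoder hidden states
labels = jnp.zeros((2, 4), dtype=jnp.int32)        # dummy target token ids
params, loss = train_step(params, hidden_states, labels)
```

The real script wraps the same structure around the pretrained Whisper forward pass and an optax schedule; this sketch only shows the loss/grad/update skeleton.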
Will add a README with preliminary training configs / results later this week after doing a full fine-tuning run.
cc @peregilk @andyehrenberg for interest
@sanchit-gandhi @andyehrenberg
We have made a version of this script that supports streaming and training on TPU pods.
The current version of the script is available here: https://github.com/NbAiLab/nb-whisper/blob/main/run_flax_speech_recognition_seq2seq_streaming.py
We are, however, struggling with a bug at the moment. The script works for training the Tiny model on multiple pod sizes, both when scaling for speed and when increasing the batch size. All the other model sizes (small, base, medium, large) also work on a single TPU v4-8. However, training the non-Tiny model sizes runs for a few steps and then freezes.
If anyone has any idea why this could be happening, I would really appreciate it.
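For reference, the pod-scaling being described relies on `jax.pmap` sharding the global batch across devices; the per-device batch-size knob corresponds to the middle axis in a sketch like this (self-contained, runs on however many local devices are present):

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# A global batch is reshaped to (n_devices, per_device_batch, ...) before pmap,
# so each device sees only its own shard.
global_batch = jnp.arange(8.0).reshape(8, 1)
per_device_batch = global_batch.reshape(n_dev, -1, 1)

@jax.pmap
def device_sum(x):
    # Each device reduces its shard; a real train step would also use
    # jax.lax.pmean to average gradients across devices.
    return x.sum()

shard_sums = device_sum(per_device_batch)  # one partial sum per device
total = shard_sums.sum()                   # combine shards on the host
```

Scaling a pod for throughput keeps the per-device shard fixed and grows `n_dev`; scaling for batch size grows the middle axis instead.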
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Given the popularity of the PyTorch fine-tuning script and Whisper JAX, adding a Whisper fine-tuning script in JAX/Flax is a pretty easy addition.
Note: this is largely based on the distil-whisper training script, but simplified to run offline, with just one training dataset and the cross-entropy objective: https://github.com/huggingface/distil-whisper#training
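One detail of the cross-entropy objective in the HF seq2seq scripts is that padded label positions are masked out of the loss (conventionally marked with `-100`). A minimal sketch of that masking, assuming the usual `-100` padding convention:

```python
import jax
import jax.numpy as jnp

def masked_cross_entropy(logits, labels, pad_id=-100):
    # 1.0 for real label positions, 0.0 for padding.
    mask = (labels != pad_id).astype(logits.dtype)
    # Replace pad ids with a valid index so the gather is safe; the mask
    # zeroes out their contribution afterwards.
    safe_labels = jnp.where(labels == pad_id, 0, labels)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1).squeeze(-1)
    # Average only over the unmasked (real) tokens.
    return (nll * mask).sum() / mask.sum()

logits = jnp.zeros((1, 3, 4))       # uniform logits over 4 classes
labels = jnp.array([[1, 2, -100]])  # last position is padding
loss = masked_cross_entropy(logits, labels)
```

With uniform logits over 4 classes, each real token contributes log(4) to the loss and the padded position contributes nothing.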