
[Flax Examples] Seq2Seq ASR Fine-Tuning Script

sanchit-gandhi opened this pull request 1 year ago • 2 comments

What does this PR do?

Can be used to fine-tune Flax Whisper for speech recognition.

Tested and verified as working with the following (dummy) config:

run_flax_speech_recognition_seq2seq.py \
            --model_name_or_path openai/whisper-tiny.en \
            --dataset_name hf-internal-testing/librispeech_asr_dummy \
            --dataset_config clean \
            --train_split_name validation \
            --eval_split_name validation \
            --output_dir whisper-tiny-ft-dummy \
            --overwrite_output_dir \
            --num_train_epochs=2 \
            --max_train_samples 10 \
            --max_eval_samples 10 \
            --warmup_steps=8 \
            --do_train \
            --do_eval \
            --learning_rate=2e-4 \
            --per_device_train_batch_size=2 \
            --per_device_eval_batch_size=1 \
            --predict_with_generate

Will add a README with preliminary training configs / results later this week after doing a full fine-tuning run.
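For reference, a checkpoint produced by a run like the one above can be loaded for Flax inference roughly as follows. This is a minimal sketch rather than part of the PR; it assumes the whisper-tiny-ft-dummy output directory from the dummy config and reuses the same dummy dataset:

from datasets import load_dataset
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration

# Load the processor and Flax weights written to the output directory above
processor = WhisperProcessor.from_pretrained("whisper-tiny-ft-dummy")
model = FlaxWhisperForConditionalGeneration.from_pretrained("whisper-tiny-ft-dummy")

# Take one sample from the same dummy dataset used in the test run
sample = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)[0]["audio"]

inputs = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np"
)

# Greedy generation with the fine-tuned model, then decode token ids to text
pred_ids = model.generate(inputs.input_features).sequences
print(processor.batch_decode(pred_ids, skip_special_tokens=True))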

cc @peregilk @andyehrenberg for interest

sanchit-gandhi · Feb 23, 2023

The documentation is not available anymore as the PR was closed or merged.

@sanchit-gandhi @andyehrenberg

We have made a version of this script that supports streaming and training on TPU pods.

The current version of the script is available here: https://github.com/NbAiLab/nb-whisper/blob/main/run_flax_speech_recognition_seq2seq_streaming.py
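For context, the streaming part means the data is loaded lazily with the datasets library instead of being downloaded up front. A rough sketch of the general pattern (the dataset name here is only illustrative, not the one used in the linked script):

from datasets import load_dataset

# streaming=True returns an IterableDataset: examples are fetched on the fly,
# so the full corpus never has to fit on local disk
train_data = load_dataset(
    "librispeech_asr", "clean", split="train.100", streaming=True
)

for example in train_data.take(2):
    print(example["audio"]["sampling_rate"], example["text"])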

We are, however, struggling with a bug at the moment. The script seems to work for training the Tiny models on multiple pod sizes, both when scaling for speed and when increasing the batch size. All the other model sizes (small, base, medium, large) also work on a single TPU v4-8. However, training the non-Tiny model sizes on the TPU pods runs for a few steps and then freezes.

If anyone has any idea why this could be happening, I would really appreciate it.
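For reference, a generic first sanity check when a multi-host JAX run hangs (not taken from the linked script): every process on the pod should see the same global device count and a distinct process index, and all hosts have to execute the same number of steps or the collectives will block.

import jax

# On a TPU pod slice, each process should report the same global device count
# and its own process index; if hosts fall out of sync (e.g. they run a
# different number of steps), pmap's collective ops hang instead of erroring.
print(
    f"process {jax.process_index()} / {jax.process_count()}, "
    f"local devices: {jax.local_device_count()}, "
    f"global devices: {jax.device_count()}"
)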

peregilk · Apr 3, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Sep 5, 2023

Given the popularity of the PyTorch fine-tuning script and Whisper JAX, adding a Whisper fine-tuning script in JAX/Flax is a pretty easy addition.

Note: this is largely based on the distil-whisper training script (https://github.com/huggingface/distil-whisper#training), but simplified to run offline, with just one training dataset and the cross-entropy objective.
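For readers unfamiliar with the objective: the cross-entropy mentioned above is the standard token-level seq2seq cross-entropy with padded label positions masked out. A minimal sketch of that loss (not the PR's exact implementation; the label padding id is an assumption):

import jax
import jax.numpy as jnp
import optax

def cross_entropy_loss(logits, labels, label_pad_id=-100):
    # logits: (batch, seq_len, vocab_size), labels: (batch, seq_len)
    # Positions equal to label_pad_id are padding and are excluded from the loss.
    vocab_size = logits.shape[-1]
    mask = (labels != label_pad_id).astype(jnp.float32)
    # Replace padded ids with 0 so one_hot stays in range; they are masked out anyway.
    onehot = jax.nn.one_hot(jnp.where(mask > 0, labels, 0), vocab_size)
    per_token = optax.softmax_cross_entropy(logits, onehot)
    return (per_token * mask).sum() / mask.sum()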

sanchit-gandhi · Sep 28, 2023