
Feature/asr

Open rafaelvp-db opened this issue 2 years ago • 6 comments

Initial PR for ASR support

rafaelvp-db avatar May 20 '22 08:05 rafaelvp-db

Hey @SeanNaren, I just made the PR. It would be great if you could take a look and share some insights :)

rafaelvp-db avatar May 20 '22 08:05 rafaelvp-db

Hi @rafaelvp-db thanks for making the PR!

There are quite a few pieces missing; hopefully I can assist in getting the right things implemented!

Firstly, I don't think your current code is entirely correct (unless I've made a mistake). The Wav2Vec model/dataset/tokenizer are completely different from the existing tasks and should probably exist as new classes inheriting from TaskTransformer/TransformerDataModule.

What I think would be a good idea is to get this blog post implemented into a TaskTransformer and a TransformerDataModule as you've already outlined: https://huggingface.co/blog/fine-tune-wav2vec2-english

This would involve:

  • Creating the logic to run the training_step, validation_step and test_step in the SpeechRecognitionTransformer, using the WER metric found in torchmetrics (see the sketch after this list): https://torchmetrics.readthedocs.io/en/stable/text/word_error_rate.html?highlight=WER
  • Creating an actual dataset using the pre-processing logic found in the blog post, putting the processing logic in SpeechRecognitionDataModule
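
For illustration, a minimal sketch of what those steps could look like, using the WordErrorRate metric from torchmetrics (class names, batch layout, and the constructor signature are assumptions based on the blog post, not the actual lightning-transformers API):

import torch
import pytorch_lightning as pl
from torchmetrics.text import WordErrorRate


class SpeechRecognitionTransformer(pl.LightningModule):
    """Sketch only; the real class would inherit from TaskTransformer."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model          # e.g. a transformers Wav2Vec2ForCTC instance
        self.tokenizer = tokenizer  # e.g. a Wav2Vec2CTCTokenizer
        self.wer = WordErrorRate()

    def training_step(self, batch, batch_idx):
        # The HF model returns the CTC loss when labels are present in the batch.
        outputs = self.model(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        pred_ids = torch.argmax(outputs.logits, dim=-1)
        preds = self.tokenizer.batch_decode(pred_ids)
        # The -100 padding used by the CTC loss must be replaced before decoding.
        labels = batch["labels"].clone()
        labels[labels == -100] = self.tokenizer.pad_token_id
        refs = self.tokenizer.batch_decode(labels, group_tokens=False)
        self.wer(preds, refs)
        self.log("val_wer", self.wer, on_epoch=True)
        self.log("val_loss", outputs.loss)

    # test_step would mirror validation_step with a separate metric instance.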

Overall I would assume something like this to work:

import pytorch_lightning as pl

from lightning_transformers.task.audio.speech_recognition import (
    SpeechRecognitionDataConfig,
    SpeechRecognitionDataModule,
    SpeechRecognitionTransformer,
)

if __name__ == "__main__":
    model = SpeechRecognitionTransformer("facebook/wav2vec2-base", ctc_loss_reduction="mean", vocab_file="vocab.json")
    dm = SpeechRecognitionDataModule(
        cfg=SpeechRecognitionDataConfig(
            batch_size=1,
            dataset_name="timit_asr",
        ),
        tokenizer=model.tokenizer,
    )
    trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

    trainer.fit(model, dm)
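
The data side would then mirror the blog post's preprocessing. A rough sketch of what SpeechRecognitionDataModule would need to do internally (column names follow the blog post; prepare_timit itself is a hypothetical helper, not existing API):

from datasets import load_dataset
from transformers import Wav2Vec2Processor


def prepare_timit(processor: Wav2Vec2Processor):
    timit = load_dataset("timit_asr")
    # Drop everything except the audio and its transcription.
    timit = timit.remove_columns(
        ["phonetic_detail", "word_detail", "dialect_region",
         "id", "sentence_type", "speaker_id"]
    )

    def prepare_batch(batch):
        audio = batch["audio"]
        # Turn the raw waveform into model inputs...
        batch["input_values"] = processor(
            audio["array"], sampling_rate=audio["sampling_rate"]
        ).input_values[0]
        # ...and encode the transcription into label ids for the CTC loss.
        with processor.as_target_processor():
            batch["labels"] = processor(batch["text"]).input_ids
        return batch

    return timit.map(prepare_batch, remove_columns=timit["train"].column_names)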

SeanNaren avatar May 23 '22 09:05 SeanNaren

Thanks for the guidelines @SeanNaren! Let me look into that.

rafaelvp-db avatar May 25 '22 19:05 rafaelvp-db

Codecov Report

Merging #251 (a3d8528) into master (9f25baa) will decrease coverage by 1%. The diff coverage is 58%.

:exclamation: Current head a3d8528 differs from pull request most recent head 45cbab7. Consider uploading reports for the commit 45cbab7 to get more accurate results.

@@          Coverage Diff          @@
##           master   #251   +/-   ##
=====================================
- Coverage      75%    74%   -1%     
=====================================
  Files          73     77    +4     
  Lines        1622   1682   +60     
=====================================
+ Hits         1210   1245   +35     
- Misses        412    437   +25     

codecov[bot] avatar May 29 '22 08:05 codecov[bot]

How's it going, @rafaelvp-db?

The code looks muuuch nicer, amazing job! Anything I can assist with? I notice that the example requires a vocab.json; I'm sure we can YOLO it and use the alphabet.
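
If we go that route, building the vocab straight from the alphabet rather than scanning the dataset could be as simple as this (the exact character set here is an assumption):

import json
import string

# Minimal CTC vocab built directly from the alphabet, no dataset scan needed.
vocab = {c: i for i, c in enumerate(string.ascii_lowercase)}
vocab["|"] = len(vocab)      # word delimiter expected by Wav2Vec2CTCTokenizer
vocab["'"] = len(vocab)
vocab["[UNK]"] = len(vocab)
vocab["[PAD]"] = len(vocab)  # also serves as the CTC blank token

with open("vocab.json", "w") as f:
    json.dump(vocab, f)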

SeanNaren avatar May 30 '22 12:05 SeanNaren

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot] avatar Sep 20 '22 22:09 stale[bot]