lightning-transformers
Feature/asr
Initial PR for ASR support
Hey @SeanNaren, just made the PR. It would be great if you could take a look and share some insights :)
Hi @rafaelvp-db, thanks for making the PR!
There are quite a few pieces missing; hopefully I can help you get the right things implemented!
Firstly, I don't think your current code is entirely correct (unless I've made a mistake). The Wav2Vec model, dataset and tokenizer are completely different and should probably exist as new classes inheriting from TaskTransformer/TransformerDataModule.
I think a good approach would be to implement the fine-tuning from this blog post as a TaskTransformer and a TransformerDataModule, as you've already outlined: https://huggingface.co/blog/fine-tune-wav2vec2-english
This would involve:
- Creating the logic to run the training_step, validation_step and test_step in the SpeechRecognitionTransformer, using the WER metric found in torchmetrics: https://torchmetrics.readthedocs.io/en/stable/text/word_error_rate.html?highlight=WER
- Creating an actual dataset using the pre-processing logic found in the blog post, putting that processing logic in SpeechRecognitionDataModule.
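For validation_step and test_step, the per-frame predictions have to be turned back into text before WER can be computed. Below is a minimal sketch, assuming greedy (argmax) CTC decoding and the blog post's conventions that "|" marks word boundaries and the [PAD] id doubles as the CTC blank. greedy_ctc_decode and word_error_rate are hypothetical helpers; the word-level edit distance only illustrates what torchmetrics' WordErrorRate computes (total word edits divided by total reference words).

```python
from typing import Dict, List, Sequence


def greedy_ctc_decode(pred_ids: Sequence[int], id_to_char: Dict[int, str], blank_id: int) -> str:
    """Collapse consecutive repeats, drop blanks, and map ids back to text."""
    chars = []
    previous = None
    for idx in pred_ids:
        # Only consecutive duplicates collapse; a blank between two ids keeps both.
        if idx != previous and idx != blank_id:
            chars.append(id_to_char[idx])
        previous = idx
    # "|" is the word delimiter; normalise it back to single spaces.
    return " ".join("".join(chars).replace("|", " ").split())


def word_error_rate(predictions: List[str], references: List[str]) -> float:
    """Corpus-level WER: total word edits divided by total reference words."""
    errors, total = 0, 0
    for pred, ref in zip(predictions, references):
        p, r = pred.split(), ref.split()
        # Single-row Levenshtein distance over words.
        row = list(range(len(p) + 1))
        for i in range(1, len(r) + 1):
            prev_diag, row[0] = row[0], i
            for j in range(1, len(p) + 1):
                above = row[j]
                row[j] = min(
                    above + 1,  # deletion: reference word missing from prediction
                    row[j - 1] + 1,  # insertion: extra predicted word
                    prev_diag + (r[i - 1] != p[j - 1]),  # substitution, or 0 on match
                )
                prev_diag = above
        errors += row[len(p)]
        total += len(r)
    return errors / total


if __name__ == "__main__":
    id_to_char = {0: "[PAD]", 1: "|", 2: "c", 3: "a", 4: "t", 5: "s"}
    frames = [2, 2, 0, 3, 4, 4, 1, 0, 5, 3, 4]  # hypothetical argmax ids per audio frame
    text = greedy_ctc_decode(frames, id_to_char, blank_id=0)
    print(text)  # cat sat
    print(word_error_rate([text], ["the cat sat"]))  # 0.3333333333333333
```

In the real SpeechRecognitionTransformer the ids would come from the argmax of the logits tensor and the references from the batch labels, with torchmetrics' WordErrorRate taking the place of the hand-rolled metric.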
Overall, I would expect something like this to work:
```python
import pytorch_lightning as pl

# The SpeechRecognition* classes are the ones proposed in this PR
from lightning_transformers.task.audio.speech_recognition import (
    SpeechRecognitionDataConfig,
    SpeechRecognitionDataModule,
    SpeechRecognitionTransformer,
)

if __name__ == "__main__":
    model = SpeechRecognitionTransformer("facebook/wav2vec2-base", ctc_loss_reduction="mean", vocab_file="vocab.json")
    dm = SpeechRecognitionDataModule(
        cfg=SpeechRecognitionDataConfig(
            batch_size=1,
            dataset_name="timit_asr",
        ),
        tokenizer=model.tokenizer,
    )
    trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
    trainer.fit(model, dm)
```
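With batch_size=1 nothing needs padding, but for larger batches SpeechRecognitionDataModule would need a collate function. The blog post pads the audio and replaces label padding with -100 so the CTC loss ignores those positions. The sketch below mimics that in plain Python; pad_batch, the "input_values"/"labels" keys and zero-padding of audio are assumptions for illustration.

```python
from typing import Dict, List

# The blog post replaces label padding with -100 so the CTC loss ignores it.
LABEL_PAD_ID = -100


def pad_batch(features: List[Dict[str, list]], audio_pad_value: float = 0.0) -> Dict[str, List[list]]:
    """Hypothetical collate helper: right-pad audio and label ids to the longest item."""
    max_audio = max(len(f["input_values"]) for f in features)
    max_labels = max(len(f["labels"]) for f in features)
    batch: Dict[str, List[list]] = {"input_values": [], "labels": []}
    for f in features:
        audio, labels = list(f["input_values"]), list(f["labels"])
        batch["input_values"].append(audio + [audio_pad_value] * (max_audio - len(audio)))
        batch["labels"].append(labels + [LABEL_PAD_ID] * (max_labels - len(labels)))
    return batch


if __name__ == "__main__":
    features = [
        {"input_values": [0.1, 0.2, 0.3], "labels": [2, 3]},
        {"input_values": [0.5], "labels": [4, 1, 5]},
    ]
    print(pad_batch(features)["labels"])  # [[2, 3, -100], [4, 1, 5]]
```

A real collate_fn would return torch tensors; the blog post does the padding itself through its processor.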
Thanks for the guidelines @SeanNaren! Let me look into that.
Codecov Report
Merging #251 (a3d8528) into master (9f25baa) will decrease coverage by 1%. The diff coverage is 58%.
❗ Current head a3d8528 differs from pull request most recent head 45cbab7. Consider uploading reports for the commit 45cbab7 to get more accurate results.
```diff
@@          Coverage Diff           @@
##           master    #251    +/-   ##
=====================================
- Coverage      75%     74%     -1%
=====================================
  Files          73      77      +4
  Lines        1622    1682     +60
=====================================
+ Hits         1210    1245     +35
- Misses        412     437     +25
```
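The coverage figures are whole-percent roundings of hits / (hits + misses). Recomputing them from the table above is a quick sanity check that the reported -1% corresponds to a drop of roughly 0.6 percentage points:

```python
def percent(part: int, whole: int) -> float:
    """Share of tracked lines that are covered, in percent."""
    return 100 * part / whole


before = percent(1210, 1210 + 412)  # master: hits / (hits + misses) = 1210 / 1622
after = percent(1245, 1245 + 437)  # with this PR: 1245 / 1682
print(f"{before:.1f}% -> {after:.1f}%")  # 74.6% -> 74.0%
```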
How's it going @rafaelvp-db?
The code looks muuuch nicer, amazing job! Anything I can assist with? I noticed that the example requires a vocab.json; I'm sure we can YOLO it and just use the alphabet.
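Using the alphabet instead of a vocab.json derived from the dataset could look roughly like the sketch below. It assumes lowercase English transcripts; the symbol set, the punctuation regex, and the names build_alphabet_vocab and encode_transcript are illustrative, not part of the PR. As in the blog post, "|" replaces spaces and [UNK]/[PAD] are appended at the end, with [PAD] doubling as the CTC blank.

```python
import json
import re
import string
from typing import Dict, List

# Punctuation stripped from transcripts before encoding (an assumed set).
CHARS_TO_IGNORE = re.compile(r"[,?.!\-;:\"]")


def build_alphabet_vocab() -> Dict[str, int]:
    """Hypothetical fallback vocab: a-z, apostrophe, word delimiter, [UNK], [PAD]."""
    symbols = list(string.ascii_lowercase) + ["'", "|"]
    vocab = {ch: i for i, ch in enumerate(symbols)}
    vocab["[UNK]"] = len(vocab)
    vocab["[PAD]"] = len(vocab)  # also used as the CTC blank
    return vocab


def encode_transcript(text: str, vocab: Dict[str, int]) -> List[int]:
    """Lowercase, drop ignored punctuation, map spaces to '|' and characters to ids."""
    text = CHARS_TO_IGNORE.sub("", text.lower()).strip()
    text = "|".join(text.split())  # runs of whitespace become a single delimiter
    return [vocab.get(ch, vocab["[UNK]"]) for ch in text]


if __name__ == "__main__":
    vocab = build_alphabet_vocab()
    with open("vocab.json", "w") as f:
        json.dump(vocab, f)  # the file passed as vocab_file in the example above
    print(encode_transcript("Hi, there!", vocab))  # [7, 8, 27, 19, 7, 4, 17, 4]
```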
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.