Use model-cast-to-bfloat16 rather than AMP-to-bfloat16 for inference.
What does this PR do?
Using transcribe_speech.py, I demonstrate that simply casting the entire model to bfloat16 gives about 15% higher throughput than running inference under automatic mixed precision (AMP). The reasons are discussed in #9086.
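The difference between the two execution modes can be sketched with a toy model. This is only an illustration; the `nn.Sequential` stand-in is hypothetical, not NeMo code:

```python
import torch
import torch.nn as nn

# Toy stand-in for an ASR encoder (hypothetical; not NeMo's actual model).
model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 128)).eval()
x = torch.randn(4, 80)

# AMP: weights stay float32; autocast inserts per-op casts at runtime.
with torch.inference_mode(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    amp_out = model(x)

# Full cast (this PR's approach): convert the weights once; every op then
# runs natively in bfloat16 with no per-op casting overhead.
bf16_model = model.to(torch.bfloat16)
with torch.inference_mode():
    cast_out = bf16_model(x.to(torch.bfloat16))

assert amp_out.dtype == cast_out.dtype == torch.bfloat16
```

Both paths produce bfloat16 activations, but the cast path avoids the dispatch and casting bookkeeping AMP performs around each region.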
There are a few small modifications to our Conformer encoder and multi-head attention implementations to make this work; torch.float32 was unintentionally used at a few points.
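The kind of accidental upcast being fixed can be shown in miniature. The tensors below are made up for illustration, standing in for attention scores and a float32 buffer such as a positional encoding created with the default dtype:

```python
import torch

scores = torch.randn(2, 4, 4, dtype=torch.bfloat16)   # e.g. attention logits
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool)).logical_not()

# Adding a float32 tensor (e.g. a buffer registered with the default dtype)
# silently promotes the whole result back to float32.
pe_fp32 = torch.randn(4, 4)           # accidental float32 buffer
promoted = scores + pe_fp32
assert promoted.dtype == torch.float32  # the unintended upcast

# Fix: keep buffers and constants in the working dtype. masked_fill with a
# Python float preserves the tensor's dtype, so this stays bfloat16.
pe_bf16 = pe_fp32.to(scores.dtype)
fixed = (scores + pe_bf16).masked_fill(mask, float("-inf"))
assert fixed.dtype == torch.bfloat16
```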
This PR also disables updates to the batch norm running statistics during inference. See the relevant comment in module.py.
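To see why this matters: a forward pass through a BatchNorm layer in train() mode mutates its running statistics, while eval() mode uses them read-only. A minimal sketch (the `TinyEncoder` module is hypothetical):

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    # Hypothetical module with a batch norm layer, as in the Conformer encoder.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(8, 8, 3, padding=1)
        self.bn = nn.BatchNorm1d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

model = TinyEncoder()
before = model.bn.running_mean.clone()

# In train() mode, a forward pass silently updates the BN running stats.
model.train()
model(torch.randn(2, 8, 16))
assert not torch.equal(before, model.bn.running_mean)

# In eval() mode, the stored stats are used as-is and left untouched —
# the behavior enforced during inference.
model.eval()
frozen = model.bn.running_mean.clone()
model(torch.randn(2, 8, 16))
assert torch.equal(frozen, model.bn.running_mean)
```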
Note that casting the preprocessor's input waveform to bfloat16 causes a serious accuracy degradation, and I add a warning about this. It is better to run the preprocessor in float32 and then cast its output to bfloat16, which is what this PR does.
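The recommended flow can be sketched as follows. The log-magnitude STFT here is only a stand-in for NeMo's actual mel-spectrogram preprocessor:

```python
import torch

def extract_features(waveform: torch.Tensor) -> torch.Tensor:
    # Stand-in preprocessor: FFT-based feature extraction is
    # precision-sensitive, so it runs in float32.
    window = torch.hann_window(512)
    spec = torch.stft(waveform.float(), n_fft=512, window=window, return_complex=True)
    return torch.log(spec.abs() + 1e-9)

waveform = torch.randn(16000)                          # 1 s of fake 16 kHz audio
feats = extract_features(waveform).to(torch.bfloat16)  # cast only the output
assert feats.dtype == torch.bfloat16
```

The waveform and features stay float32 throughout the FFT; only the finished features are handed to the bfloat16 encoder.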
I have verified that this works with Parakeet CTC 1.1B and Parakeet RNN-T 1.1B. I will upload a table of WER and RTFx throughput from transcribe_speech.py comparing the new casting method with the old AMP method: RTFx improves by about 15%, and WER stays about the same.
This image demonstrates (1) the increase in RTFx across a range of datasets and (2) the fact that WER does not degrade.
Collection: ASR
Changelog
- Add specific line-by-line info of high-level changes in this PR.
Usage
- See transcribe_speech.py. Note that several other parts of the codebase still use AMP rather than simply casting the model.
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR. To re-run CI, remove and add the label again. To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
- [ ] Make sure you read and followed Contributor guidelines
- [ ] Did you write any new necessary tests?
- [ ] Did you add or update any necessary documentation?
- [ ] Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
- [ ] Reviewer: Does the PR have correct import guards for all optional libraries?
PR Type:
- [ ] New Feature
- [X] Bugfix
- [ ] Documentation
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed. The Contributor guidelines list specific people who can review PRs in various areas.
Additional Information
- Related to #9086
FYI, the PR description has been updated to describe the benefit of using this mode of execution instead of AMP.
This is ready for another round of review. Unfortunately, the black + isort change made a lot of noise, but mostly I removed cruft. This PR now does only three things:
- Modify transcribe_speech.py to run with the model cast to float16/bfloat16 instead of under AMP
- Modify transcribe_speech.py to output RTFx scores via the calculate_rtfx=True config option. It defaults to False (though I wouldn't be opposed to setting it to True later, for the sake of showing users their throughput).
- Make small changes to the positional encoding buffer and a call to masked_fill() to ensure that data is not accidentally cast up to float32.
@nithinraok @pzelasko @titu1994 I think this change is good to go at this point.
Looks good to me!