RNN-T: enable batch decoder
Pitch
Optimize the decoder of RNN-T to support batch mode.
Motivation
The current RNN-T decoder loops over the batch dimension and processes only one sample (BS = 1) at a time (https://github.com/mlcommons/inference/blob/master/speech_recognition/rnnt/pytorch/decoders.py#L66-L71):
```python
for batch_idx in range(logits.size(0)):
    inseq = logits[batch_idx, :, :].unsqueeze(1)
    # inseq: TxBxF
    logitlen = logits_lens[batch_idx]
    sentence = self._greedy_decode(inseq, logitlen)
    output.append(sentence)
```
In throughput mode, where BS > 1, this per-sample loop is inefficient. In this PR we propose an optimization of the greedy decoder to handle batched input directly.
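For reference, below is a minimal sketch of how the per-sample loop could be turned into a batched greedy decode. This is not the code of this PR: the `model.predict`/`model.joint` interfaces, the blank-as-start-token convention, the hidden-state layout, and the zero initial LSTM state are all assumptions made for the illustration. The key idea is that per-sample bookkeeping (valid time steps, last emitted label, prediction-network state) is carried in masks so that all samples advance in lockstep and padded frames never emit symbols.

```python
import torch


def greedy_batch_decode(model, x, out_lens, blank_id, max_symbols=30):
    """Greedy RNN-T decoding over a whole batch at once (sketch).

    Assumptions (not the upstream API): `model.predict(labels, hidden)`
    returns (g, hidden) and `model.joint(f, g)` returns logits of shape
    (B, num_classes), both accepting a leading batch dimension; `hidden`
    is an LSTM-style (h, c) tuple of shape (num_layers, B, H) whose
    default initial state is zeros; the blank symbol doubles as the
    start token for the prediction network.

    x: encoder output of shape (B, T, F); out_lens: valid lengths, shape (B,).
    """
    B, T, _ = x.shape
    hidden = None
    last_label = torch.full((B, 1), blank_id, dtype=torch.long, device=x.device)
    results = [[] for _ in range(B)]

    for t in range(T):
        f = x[:, t, :].unsqueeze(1)        # (B, 1, F): frame t for every sample
        # Only samples whose frame t is real (not padding) may emit here.
        not_blank = (t < out_lens)
        symbols_added = 0

        while not_blank.any() and symbols_added < max_symbols:
            g, new_hidden = model.predict(last_label, hidden)   # g: (B, 1, H)
            logits = model.joint(f, g)                          # (B, num_classes)
            k = logits.argmax(dim=-1)                           # (B,)

            # Samples predicting blank (or sitting on padding) move on to the
            # next frame; the rest emit their symbol and keep decoding at t.
            emit = not_blank & (k != blank_id)
            for b in emit.nonzero(as_tuple=False).flatten().tolist():
                results[b].append(int(k[b]))

            if emit.any():
                if hidden is None:
                    # nn.LSTM treats a missing state as zeros, so start there.
                    hidden = tuple(torch.zeros_like(h) for h in new_hidden)
                # Only emitting samples advance their prediction-network state
                # and last label; everyone else keeps the previous values.
                keep = emit.view(1, B, 1)
                hidden = tuple(torch.where(keep, new_h, old_h)
                               for new_h, old_h in zip(new_hidden, hidden))
                last_label = torch.where(emit.view(B, 1), k.view(B, 1), last_label)

            not_blank = emit
            symbols_added += 1

    return results
```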
Data
| | With the original decoder | With batch decoder |
|---|---|---|
| WER | 7.452253714852645% | 7.452253714852645% |
This pull request introduces 3 alerts when merging 6f23edcea21ea0ee610736606cd6d1a6ac6b2c1c into de6497f9d64b85668f2ab9c26c9e3889a7be257b - view on LGTM.com
new alerts:
- 1 for Testing equality to None
- 1 for Unused local variable
- 1 for Variable defined multiple times
@galv Can you please review the PR?
@ashwin Can someone in NV review this PR?
Hi, I suggest holding off on this integration, considering the issue below.
This PR aims to provide a batched decoder so that the model can run end-to-end inference with BS > 1. However, to ensure computational correctness and pass the compliance test, the preprocessor and encoder also need to be modified (which this PR does not include), because the padding introduced by batching participates in the end-to-end computation. The padding can appear not only in the T (sequence-length) dimension but also in the C (channel) dimension, due to preprocessor::FrameSplicing and encoder::StackTime. The amount and values of a given sample's padding may differ when it is padded to different max_len values in different batches, which makes the accuracy unstable. More details on this issue can be found at https://github.com/mlcommons/inference/issues/801.
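To make the concern concrete, here is a small, self-contained illustration of the failure mode: any stage that reduces or stacks over the padded time axis without a length mask makes the same utterance's features depend on the max_len of whichever batch it lands in. The per-feature normalization below is only a stand-in for such a stage (FrameSplicing and StackTime are the components named above); the shapes and values are made up for the example.

```python
import torch

torch.manual_seed(0)
utterance = torch.randn(5, 3)   # one utterance: T=5 frames, 3 features


def pad_to(x, max_len):
    """Zero-pad the time axis up to a batch's max_len, as batching requires."""
    return torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))


def unmasked_stage(x):
    """Stand-in for any stage that touches the padded region without a length
    mask; here, a per-feature mean/std normalization over the full time axis."""
    return (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-5)


# The same utterance lands in two batches whose longest samples differ:
view_a = unmasked_stage(pad_to(utterance, 6))[:5]    # padded to max_len = 6
view_b = unmasked_stage(pad_to(utterance, 10))[:5]   # padded to max_len = 10

# The "valid" frames are no longer identical, so the recognizer's output (and
# hence WER) becomes sensitive to how the dataset was batched.
print(torch.allclose(view_a, view_b))        # False
print((view_a - view_b).abs().max())         # clearly non-zero
```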