
RNN-T: enable batch decoder

Open chunyuan-w opened this issue 3 years ago • 3 comments

Pitch

Optimize the decoder of RNN-T to support batch mode.

Motivation

The current RNN-T decoder loops over batch_size and can only process one sample (BS = 1) at a time (https://github.com/mlcommons/inference/blob/master/speech_recognition/rnnt/pytorch/decoders.py#L66-L71):

for batch_idx in range(logits.size(0)):
    inseq = logits[batch_idx, :, :].unsqueeze(1)
    # inseq: TxBxF
    logitlen = logits_lens[batch_idx]
    sentence = self._greedy_decode(inseq, logitlen)
    output.append(sentence)

In throughput mode, where BS > 1, the current implementation is inefficient. In this PR we propose an optimization of the greedy decoder to handle batch mode.
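
For illustration, below is a minimal sketch of what a batched greedy RNN-T decode loop could look like. It is not the code in this PR; pred_step, joint_step, init_state, blank_id, and the tensor layouts are assumptions standing in for the model's prediction network, joint network, and state initialization:

import torch

def batched_greedy_decode(encoder_out, encoder_lens, pred_step, joint_step,
                          init_state, blank_id, max_symbols_per_step=30):
    # encoder_out: (B, T, F) encoder features; encoder_lens: (B,) valid lengths.
    # pred_step(labels, state) -> (pred_out, new_state) and
    # joint_step(enc_t, pred_out) -> (B, V) logits are hypothetical callables
    # standing in for the prediction and joint networks; init_state is a tuple
    # of (num_layers, B, H) tensors (e.g. LSTM h and c initialized to zeros).
    B, T, _ = encoder_out.shape
    hypotheses = [[] for _ in range(B)]
    last_label = torch.full((B,), blank_id, dtype=torch.long,
                            device=encoder_out.device)
    state = init_state
    for t in range(T):
        enc_t = encoder_out[:, t, :]        # frame t for every sample: (B, F)
        not_blank = t < encoder_lens        # only samples with valid frames decode
        symbols = 0
        while bool(not_blank.any()) and symbols < max_symbols_per_step:
            pred_out, new_state = pred_step(last_label, state)
            logits = joint_step(enc_t, pred_out)        # (B, V)
            k = logits.argmax(dim=-1)                   # greedy symbol per sample
            emit = (k != blank_id) & not_blank          # samples emitting a label now
            for b in emit.nonzero().flatten().tolist():
                hypotheses[b].append(int(k[b]))
            # Only samples that emitted a non-blank advance their prediction-network
            # state and last label; samples that produced blank keep their old state.
            mask = emit.view(1, -1, 1)
            state = tuple(torch.where(mask, n, s) for n, s in zip(new_state, state))
            last_label = torch.where(emit, k, last_label)
            not_blank = emit
            symbols += 1
    return hypotheses

The key difference from the per-sample loop above is that blank vs. non-blank is tracked with a per-sample mask, so samples that emit blank simply freeze while the rest of the batch keeps expanding at the same time step.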

Data

       With the original decoder   With batch decoder
WER    7.452253714852645%          7.452253714852645%

chunyuan-w · Apr 22 '22 06:04

MLCommons CLA bot:
Thank you for your submission, we really appreciate it. We ask that you sign our MLCommons CLA and be a member before we can accept your contribution. If you are interested in membership, please contact [email protected].
0 out of 1 committers have signed the MLCommons CLA.
:x: @chunyuan-w
You can retrigger this bot by commenting recheck in this Pull Request

github-actions[bot] · Apr 22 '22 06:04

This pull request introduces 3 alerts when merging 6f23edcea21ea0ee610736606cd6d1a6ac6b2c1c into de6497f9d64b85668f2ab9c26c9e3889a7be257b - view on LGTM.com

new alerts:

  • 1 for Testing equality to None
  • 1 for Unused local variable
  • 1 for Variable defined multiple times

lgtm-com[bot] · Apr 22 '22 06:04

@galv Can you please review the PR?

rnaidu02 · Jun 14 '22 15:06

@ashwin Can someone in NV review this PR?

rnaidu02 · Feb 07 '23 01:02

Hi, I suggest holding off on the integration, considering the issue below.

This PR aims to provide a batched version of the decoder so that the model can run end-to-end inference with BS > 1. However, to ensure computational correctness and pass the compliance tests, the preprocessor and encoder would also need to be modified (which this PR does not include), because the padding introduced by batching participates in the end-to-end computation. That padding shows up not only in the T (sequence length) dimension but also in the C (channel) dimension, due to preprocessor::FrameSplicing and encoder::StackTime. The amount and value of a given sample's padding can also differ between batches, since the sample is padded to each batch's max_len, which makes the accuracy unstable. More details can be found in https://github.com/mlcommons/inference/issues/801.
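
To make the T-to-C leakage concrete, here is a toy sketch; stack_time below is a simplified stand-in for StackTime/FrameSplicing-style stacking (assuming zero padding and a stacking factor of 2), not the actual preprocessor or encoder code:

import torch
import torch.nn.functional as F

def stack_time(x, factor=2):
    # Toy StackTime: concatenate `factor` consecutive frames along the feature
    # dimension, so T shrinks by `factor` and C grows by `factor`.
    B, T, C = x.shape
    T_out = T // factor
    return x[:, :T_out * factor, :].reshape(B, T_out, factor * C)

# The same 3-frame sample, zero-padded to two different batch max_lens.
sample = torch.arange(1., 10.).reshape(1, 3, 3)      # (1, T=3, C=3)
padded4 = F.pad(sample, (0, 0, 0, 1))                # padded to T=4
padded6 = F.pad(sample, (0, 0, 0, 3))                # padded to T=6

out4 = stack_time(padded4)   # (1, 2, 6): 2 output frames
out6 = stack_time(padded6)   # (1, 3, 6): 3 output frames for the same sample
print(out4[0, 1])            # second half of this frame is padding folded into C
print(out6[0, 2])            # an extra output frame made entirely of padding

After stacking, the padding is no longer confined to trailing time steps that could simply be masked out: it ends up inside the channel dimension of otherwise valid frames, and the number of output frames for the same sample depends on the batch's max_len, which is the source of the instability described above.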

dbyoung18 · Feb 08 '23 06:02