
AssertionError: assert py.is_contiguous()

Open Anwarvic opened this issue 1 year ago • 3 comments

I'm working on integrating FastRNNT with SpeechBrain; see this Pull Request.

At the moment, I'm trying to train a transducer model on the French portion of the multilingual TEDx dataset (mTEDx). Whenever I train my model, I get this assertion error (the issue's title). However, the mutual_information.py file says:

# The following assertions are for efficiency
assert px.is_contiguous()
assert py.is_contiguous()

Once I comment out these two lines, everything works just fine. Using a transducer model with a pre-trained wav2vec2 encoder + one linear layer, and a one-layer GRU as a decoder, the model trains fine and I get 14.37 WER on the French test set, which is way better than our baseline.

Now, I have these two questions:

  • How do I avoid getting this AssertionError?
  • Does commenting out these two assertions hurt performance?

Your guidance is much appreciated!

Anwarvic avatar Aug 02 '22 14:08 Anwarvic

If I remember correctly, the C++ code uses a tensor accessor to access the data, which does not require a contiguous tensor.

But a contiguous tensor is more cache friendly, so I suggest changing it to

px = px.contiguous()

csukuangfj avatar Aug 02 '22 14:08 csukuangfj
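For readers following along, here is a minimal PyTorch sketch (not from the fast_rnnt code; the tensor here is made up) showing how a view such as a transpose produces a non-contiguous tensor, and how .contiguous() copies it into a row-major layout without changing its values:

```python
import torch

x = torch.arange(6).reshape(2, 3)
y = x.t()                 # transpose: shares storage, only strides change
print(y.is_contiguous())  # False

z = y.contiguous()        # copies the data into row-major order
print(z.is_contiguous())  # True
print(torch.equal(y, z))  # True: same values, different memory layout
```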

So, theoretically, commenting out these two assertions won't affect the results... right? And making the tensors contiguous will just help a little with memory?

Anwarvic avatar Aug 02 '22 16:08 Anwarvic

It says right there that it's for efficiency, so yes, using non-contiguous tensors will affect performance. Making that copy does not necessarily require more memory; it depends on whether the original (before the copy) is required for backprop. I suggest trying to add the .contiguous() call before the log_softmax, if possible: the log_softmax likely needs the output of its operation for backprop (but not the input), so the copy produced by .contiguous() before the log_softmax would likely not be held for backprop.

danpovey avatar Aug 02 '22 19:08 danpovey

@danpovey I'm sorry I didn't get what you mean by "adding the .contiguous() statement before the log_softmax".

By ".contiguous() statement", did you mean px = px.contiguous() and py = py.contiguous()... right?

Also, which log_softmax are we talking about here exactly? The one at the end of the joiner network?

Anwarvic avatar Aug 09 '22 15:08 Anwarvic

At some point in the RNN-T computation there is a normalization of log-probs, probably via log_softmax(). I meant doing it just before then. But this is probably not super critical as I think this is not going to dominate memory requirements anyway; thanks to using pruned RNN-T, we are not instantiating any really huge tensors. So you can do it to the px and py, I think, if they are not naturally contiguous.

danpovey avatar Aug 09 '22 19:08 danpovey
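As a hedged illustration of the ordering being suggested (the names, shapes, and dims below are illustrative only, not fast_rnnt's actual code): making the logits contiguous just before the log_softmax means the normalized output, which autograd typically retains, is already contiguous, so no extra copy needs to be kept around afterwards:

```python
import torch

# Hypothetical logits that happen to be a non-contiguous view
logits = torch.randn(2, 5, 7).transpose(1, 2)
print(logits.is_contiguous())  # False

# Copy to contiguous layout, then normalize log-probs
px = logits.contiguous().log_softmax(dim=-1)
print(px.is_contiguous())      # True
```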

I have added the following two lines just before this part in the mutual_information.py script:

if not px.is_contiguous(): px = px.contiguous()
if not py.is_contiguous(): py = py.contiguous()

@danpovey If you agree with what I did, feel free to close this issue!

Anwarvic avatar Aug 11 '22 09:08 Anwarvic

I think you don't need to check whether it is contiguous.

px.contiguous() is a no-op if px is already contiguous, I think.

csukuangfj avatar Aug 11 '22 09:08 csukuangfj
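A quick check of that claim in plain PyTorch (independent of fast_rnnt): when a tensor is already contiguous, .contiguous() returns the very same tensor object rather than making a copy, so the explicit is_contiguous() guard adds nothing:

```python
import torch

a = torch.randn(3, 4)     # freshly allocated tensors are contiguous
b = a.contiguous()
print(b is a)             # True: no copy was made
```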

Thanks for the help!

Anwarvic avatar Aug 11 '22 09:08 Anwarvic

@Anwarvic Where did you add this line? I think there is already a px.contiguous() in rnnt_loss.py: https://github.com/danpovey/fast_rnnt/blob/c268c3d5a005968b87a724a21082410a3ec0bac3/fast_rnnt/python/fast_rnnt/rnnt_loss.py#L810-L811

pkufool avatar Aug 11 '22 10:08 pkufool

Ok, I think I forgot about get_rnnt_logprobs and get_rnnt_logprobs_smoothed.

pkufool avatar Aug 11 '22 10:08 pkufool

My issue was in the AssertionError which only exists in the mutual_information.py script... I think.

Anwarvic avatar Aug 11 '22 12:08 Anwarvic

My issue was in the AssertionError which only exists in the mutual_information.py script... I think.

Yes, I meant that we won't call mutual_information_recursion directly; we call it from functions in rnnt_loss.py. Anyway, fixing it in mutual_information.py is OK. Thanks!

pkufool avatar Aug 11 '22 23:08 pkufool