
Methods to solve or mitigate AMP training (inference) failures

qijiaxing opened this issue on Dec 28, 2021 · 2 comments

Hi WeNet team, I noticed that AMP training may fail for some model configs or datasets. So I propose two methods which may solve or mitigate this problem, for your consideration.

Problem

I tried AMP training following the aishell recipe in wenet/examples/aishell/s0 with an increased model size (output_size: 512 and attention_heads: 8). From epoch 14, the training starts to show the warning "NaN or Inf found in input tensor". From epoch 16, it shows lots of "NaN or Inf" messages, and I got cv loss = 478.36. The training loss keeps increasing and eventually makes the training fail.

Investigation

After a bit of investigation, I found that in wenet/transformer/attention.py, class RelPositionMultiHeadedAttention(MultiHeadedAttention), function forward(…), the following steps may produce tensors whose values exceed fp16's maximum representable value (about 65504):

q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

E.g. the max value of matrix_ac can be 50944 and the max value of matrix_bd can be 18800, so their sum scores overflows to Inf in fp16.
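
To see the overflow in isolation, here is a small self-contained sketch (the magnitudes are the ones quoted above; d_k = 64 corresponds to output_size 512 with 8 attention heads):

import math
import torch

d_k = 64  # 512 / 8 heads, as in the config above

# Tensors filled with the magnitudes observed for matrix_ac and matrix_bd.
matrix_ac = torch.full((2, 2), 50944.0, dtype=torch.float16)
matrix_bd = torch.full((2, 2), 18800.0, dtype=torch.float16)

# Original order: sum first, then divide. The sum 69744 exceeds fp16's
# maximum (~65504), so the result is already Inf before the division.
print((matrix_ac + matrix_bd) / math.sqrt(d_k))   # inf

# Division moved ahead (Solution 1): both terms stay far below 65504.
print(matrix_ac / math.sqrt(d_k) + matrix_bd / math.sqrt(d_k))   # ~8718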

This problem can be solved by either of the following two methods.

Solution 1 - Move division ahead

The code section above can be modified as follows:

q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) / math.sqrt(self.d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) / math.sqrt(self.d_k)
...
scores = matrix_ac + matrix_bd
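
Spelled out in full (the matmul lines are unchanged from the original snippet), dividing each query tensor by sqrt(d_k) up front is mathematically equivalent to dividing the summed scores, but keeps matrix_ac and matrix_bd within fp16 range:

q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) / math.sqrt(self.d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) / math.sqrt(self.d_k)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
scores = matrix_ac + matrix_bd  # no further scaling needed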

With this method, I am able to do AMP training for 200 epochs, getting WER 5.97% and cv_loss 4.00. "NaN or Inf" may still appear a few times, but it doesn't cause the training to fail anymore. This method doesn't change the model structure, so it can also be applied to any already trained model to allow fp16 inference.
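
As an illustration of what fp16 inference could look like with an unchanged checkpoint, here is a minimal sketch using PyTorch autocast (model and inputs below are generic placeholders, not the WeNet inference API):

import torch

model = model.eval().cuda()  # placeholder: any trained PyTorch module
inputs = inputs.cuda()       # placeholder: the module's input tensor

with torch.no_grad(), torch.cuda.amp.autocast():
    # Matmuls inside autocast run in fp16; with the division moved
    # ahead, the attention scores stay within fp16 range.
    outputs = model(inputs)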

Solution 2 - Remove the xscale in the embedding layer

Another solution is to modify the embedding layer: in file wenet/transformer/embedding.py, class RelPositionalEncoding, function forward(..), remove the line x = x * self.xscale. With this method, I am able to do AMP training for 120 epochs, getting WER 6.05% and cv_loss 3.89. Training for more epochs should not be a problem either. This method changes the model structure a little, so it cannot be applied to an already trained model.
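
For reference, a rough sketch of where the removed line sits; the surrounding code is only an approximation of RelPositionalEncoding.forward and may not match the file exactly, the change being the removal of the xscale (sqrt of d_model) multiplication:

def forward(self, x: torch.Tensor, offset: int = 0):
    # Approximate body; only the removed line below matters.
    # x = x * self.xscale   # removed: no longer scale inputs by sqrt(d_model)
    pos_emb = self.pe[:, offset:offset + x.size(1)]
    return self.dropout(x), self.dropout(pos_emb)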

qijiaxing commented on Dec 28, 2021

Hi @qijiaxing, can you show the baseline results for comparison with these two experiments? Thanks.

yushanyong commented on Dec 29, 2021

Here is the training loss of the baseline, i.e. AMP training from scratch: [training loss plot]

Here is the training loss of Solution 1: [training loss plot]

Here is the training loss of Solution 2: [training loss plot]

qijiaxing commented on Dec 29, 2021

This issue has been automatically closed due to inactivity.

github-actions[bot] commented on Feb 10, 2024