Methods to solve or mitigate AMP training (and inference) failures
Hi WeNet team, I noticed that AMP training may fail for some model configs or datasets, so I propose two methods that may solve or mitigate this problem, for your consideration.
Problem
I tried AMP training following the aishell recipe in wenet/examples/aishell/s0, with an increased model size (output_size: 512 and attention_heads: 8). From epoch 14, it starts to show the warning NaN or Inf found in input tensor. From epoch 16, it shows lots of "NaN or Inf" warnings, and I got cv loss = 478.36. The training loss keeps increasing and eventually makes the training fail.
Investigation
After a bit of investigation, I found that in wenet/transformer/attention.py, class RelPositionMultiHeadedAttention(MultiHeadedAttention), function forward(...), the following steps may produce tensors whose values exceed fp16's maximum representable value, 65504:
```python
# add the learned positional biases to q, then move the head dim forward
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# content-content term and content-position term of the attention scores
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# the sum is formed before the division, so it can overflow in fp16
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
```
E.g., the max value of matrix_ac can reach 50944 and the max value of matrix_bd 18800, so their sum in scores becomes Inf in fp16.
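To make this concrete, here is a minimal standalone sketch (using the magnitudes quoted above, not tensors from the actual run) showing the overflow:

```python
import torch

# the two attention terms, at the magnitudes observed above
matrix_ac = torch.tensor([50944.0], dtype=torch.float16)
matrix_bd = torch.tensor([18800.0], dtype=torch.float16)

# 50944 + 18800 = 69744 exceeds the fp16 maximum (65504)
scores = matrix_ac + matrix_bd
print(scores)  # tensor([inf], dtype=torch.float16)
```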
This problem can be solved by either of two methods.
Solution 1 - Move division ahead
The above code section can be modified to:
```python
# divide by sqrt(d_k) before the matmuls so the intermediates stay in range
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) / math.sqrt(self.d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) / math.sqrt(self.d_k)
...
scores = matrix_ac + matrix_bd
```
With this method, I was able to run AMP training for 200 epochs, getting WER 5.97%, cv_loss 4.00. "NaN or Inf" may still appear a few times, but it no longer causes training to fail. This method doesn't change the model structure, so it can be used with any already trained model to allow fp16 inference.
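As a sanity check, here is a standalone sketch (made-up shapes and scales, not WeNet code; fp16 matmul may require a GPU or a recent PyTorch build on CPU) showing that dividing q first is mathematically equivalent but avoids the overflow:

```python
import math
import torch

torch.manual_seed(0)
d_k = 64
# deliberately large activations: per-score magnitude ~ sqrt(d_k) * 100 * 100
q = (torch.randn(2, 4, 10, d_k) * 100).half()
k = (torch.randn(2, 4, 10, d_k) * 100).half()

# original order: matmul first, divide after -- the matmul output overflows
scores_after = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

# Solution 1 order: divide first, matmul after -- intermediates stay in range
scores_before = torch.matmul(q / math.sqrt(d_k), k.transpose(-2, -1))

print(torch.isinf(scores_after).any())   # tensor(True): overflow before dividing
print(torch.isinf(scores_before).any())  # tensor(False): same math, no overflow
```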
Solution 2 - Remove the input scaling in the embedding layer
Another solution is to modify the embedding layer: in wenet/transformer/embedding.py, class RelPositionalEncoding, function forward(..), remove the line x = x * self.xscale.
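For reference, here is a sketch of the modified forward. It assumes the structure of RelPositionalEncoding at the time (self.pe, self.xscale, self.dropout, and the offset argument); the exact body may differ across WeNet versions:

```python
def forward(self, x: torch.Tensor, offset: int = 0):
    self.pe = self.pe.to(x.device)
    # x = x * self.xscale  # removed: no more sqrt(d_model) scaling of the input
    pos_emb = self.pe[:, offset:offset + x.size(1)]
    return self.dropout(x), self.dropout(pos_emb)
```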
With this method, I was able to run AMP training for 120 epochs, getting WER 6.05%, cv_loss 3.89. Training for more epochs should not be a problem either.
This method changes the model's structure a little, so it cannot be used with an already trained model.
Hi @qijiaxing, can you show the baseline results for these two experiments? Thanks.
Here is the training loss of the baseline, i.e. AMP training from scratch.
Here is the training loss of Solution 1.
Here is the training loss of Solution 2.
This issue has been automatically closed due to inactivity.