fairseq

FP16 bug in grad_norm computation (results in "Minimum loss scale reached")

Open • flycser opened this issue 1 year ago • 3 comments

🐛 Bug

FloatingPointError: Minimum loss scale reached (0.0001).

I often hit this error when training in FP16. I found that it may be avoided by changing the way the grad norm is calculated.
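For context, this error comes from fairseq's dynamic loss scaler: when overflows (an inf or NaN grad norm) are detected, the update is skipped and the loss scale is reduced, and once the scale falls to the minimum (0.0001 in the message above) training stops with this FloatingPointError. The sketch below is a simplified, hypothetical model of that loop, not fairseq's DynamicLossScaler; the class and function names, the initial scale of 128, the factor of 2 and the scale window are assumptions chosen for illustration.

```python
import math


class ToyDynamicLossScaler:
    """Simplified dynamic loss scaler (hypothetical; not fairseq's class)."""

    def __init__(self, init_scale=128.0, scale_factor=2.0,
                 scale_window=2000, min_loss_scale=1e-4):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_loss_scale = min_loss_scale
        self._steps_since_overflow = 0

    def update(self, grad_norm: float) -> bool:
        """Return True if the update may be applied, False if it is skipped."""
        if math.isinf(grad_norm) or math.isnan(grad_norm):
            # Overflow: shrink the scale and skip this update.
            self.loss_scale /= self.scale_factor
            self._steps_since_overflow = 0
            if self.loss_scale <= self.min_loss_scale:
                raise FloatingPointError(
                    f"Minimum loss scale reached ({self.min_loss_scale})."
                )
            return False
        # No overflow: grow the scale again after a window of clean steps.
        self._steps_since_overflow += 1
        if self._steps_since_overflow % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
        return True


def overflow_until_failure(scaler):
    """Feed an inf grad norm on every step until the scaler gives up.

    Returns (number of skipped steps, error message).
    """
    skipped = 0
    while True:
        try:
            scaler.update(math.inf)
        except FloatingPointError as err:
            return skipped, str(err)
        skipped += 1


if __name__ == "__main__":
    print(overflow_until_failure(ToyDynamicLossScaler()))
```

With a grad norm that overflows on every step, the toy scaler skips 20 updates; the 21st overflow pushes the scale to 128 / 2**21 (about 6.1e-5), which is below the minimum, so the error is raised. Persistent overflow in the grad norm is therefore enough to end training this way.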

The reason is that in current versions of fairseq (>= 0.10), the grad norm is computed on FP32 gradients that are still scaled by the loss scale and not yet averaged, and only the resulting norm is then unscaled and averaged. The averaging coefficient num/sample_size is not applied to the gradients directly: it is just multiplied into multiply_factor in the fp16 optimizer, and multiply_factor is only applied to the gradients at the optimizer step (i.e. the gradient update). Gradients that are neither unscaled nor averaged are much larger, so they cause overflow far more often when the grad norm is computed. In version 0.8, by contrast, the averaging was done immediately when multiply_grads() was called, which avoided some of these overflows.
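The ordering argument can be made concrete with a toy model in plain Python (no fairseq or PyTorch). DeferredFactorOptimizer mimics the >= 0.10 behaviour described above (multiply_grads() only updates a pending factor, and the norm is taken on the scaled, un-averaged gradients), while EagerFactorOptimizer mimics the 0.8 behaviour (the factor is applied to the gradients immediately). Both class names and the numbers (loss scale 128, sample size 2000, averaged gradient entries of 0.01) are hypothetical. To make the overflow visible, the toy also rounds every square and partial sum to float16; that precision is an assumption for illustration, not a claim about what fairseq does.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to IEEE float16; return +/-inf when x is out of range."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)


def half_l2_norm(grads):
    """L2 norm with every square and partial sum rounded to float16 (toy)."""
    acc = 0.0
    for g in grads:
        acc = to_half(acc + to_half(g * g))
    return math.sqrt(acc)


class DeferredFactorOptimizer:
    """Hypothetical toy of the >= 0.10 ordering: factors are only recorded."""

    def __init__(self, scaled_grads, loss_scale):
        self.grads = list(scaled_grads)          # scaled, not averaged
        self.multiply_factor = 1.0 / loss_scale  # unscaling is deferred too

    def multiply_grads(self, c):
        self.multiply_factor *= c                # the gradients are untouched

    def grad_norm(self):
        # Norm of the large gradients, rescaled only afterwards.
        return self.multiply_factor * half_l2_norm(self.grads)


class EagerFactorOptimizer:
    """Hypothetical toy of the 0.8 ordering: factors hit the gradients at once."""

    def __init__(self, scaled_grads, loss_scale):
        self.grads = [g / loss_scale for g in scaled_grads]

    def multiply_grads(self, c):
        self.grads = [g * c for g in self.grads]

    def grad_norm(self):
        return half_l2_norm(self.grads)


if __name__ == "__main__":
    loss_scale, sample_size, num = 128, 2000, 1
    # Averaged gradient entry 0.01 -> scaled, summed gradient entry 2560.
    scaled = [0.01 * sample_size * loss_scale] * 4
    for opt in (DeferredFactorOptimizer(scaled, loss_scale),
                EagerFactorOptimizer(scaled, loss_scale)):
        opt.multiply_grads(num / sample_size)
        print(type(opt).__name__, opt.grad_norm())
```

In exact arithmetic both orderings give the same norm (about 0.02 here). The difference is that the deferred version squares entries of 2560, whose squares exceed the float16 maximum of 65504, while the eager version only squares entries of about 0.01.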

This error shows up often when training speech recognition models, where the sample size (number of frames) is large.

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run cmd '....'
  2. See error

Code sample

Expected behavior

Environment

  • fairseq Version (e.g., 1.0 or main):
  • PyTorch Version (e.g., 1.0)
  • OS (e.g., Linux):
  • How you installed fairseq (pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

flycser • Jul 08 '22 02:07

Hi, I have the same problem. What should I do to fix this bug? Thanks!

LoganLiu66 • Aug 12 '22 08:08

> Hi, I have the same problem. What should I do to fix this bug? Thanks!

Just have a look at the fp16 optimizer in version 0.8: call multiply_grads() before calculating the norm.
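A minimal sketch of that suggestion, using a hypothetical toy optimizer in plain Python (not the 0.8 code and not fairseq's API). Like the >= 0.10 optimizer described in the issue, it keeps a pending multiply factor, but it applies that factor to the gradients before the norm is computed. FlushBeforeNormOptimizer, _apply_pending_factor and the example numbers are illustrative assumptions.

```python
import math


class FlushBeforeNormOptimizer:
    """Hypothetical toy optimizer: factors are deferred, but flushed before the norm."""

    def __init__(self, scaled_grads, loss_scale):
        self.grads = list(scaled_grads)           # scaled, not yet averaged
        self.multiply_factor = 1.0 / loss_scale   # pending unscale factor

    def multiply_grads(self, c):
        # Lazy bookkeeping: only the pending factor changes here.
        self.multiply_factor *= c

    def _apply_pending_factor(self):
        # Unscale and average the gradients themselves, then reset the factor.
        if self.multiply_factor != 1.0:
            self.grads = [g * self.multiply_factor for g in self.grads]
            self.multiply_factor = 1.0

    def clip_grad_norm(self, max_norm):
        # The suggested change: apply the pending factor *before* the norm.
        self._apply_pending_factor()
        norm = math.sqrt(math.fsum(g * g for g in self.grads))
        if max_norm > 0.0 and norm > max_norm:
            # Clipping can stay lazy: it only rescales the pending factor.
            self.multiply_factor *= max_norm / norm
        return norm

    def step(self):
        # Apply any factor still pending (e.g. from clipping); the update itself is omitted.
        self._apply_pending_factor()


if __name__ == "__main__":
    opt = FlushBeforeNormOptimizer([2560.0] * 4, loss_scale=128)
    opt.multiply_grads(1 / 2000)               # num / sample_size
    print(opt.clip_grad_norm(max_norm=0.01))   # norm of the averaged grads
    opt.step()
    print(opt.grads)
```

Clipping stays lazy in this sketch: it only rescales the pending factor, which step() applies later. The only change relative to the deferred ordering is that the norm is taken on values that have already been unscaled and averaged.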

flycser • Aug 15 '22 01:08

That may not be the reason, because the norm is computed in FP32 (see the code):

total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
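The point about FP32 can be checked with toy numbers. The sketch below (plain Python; round_to, l2_norm and the values are hypothetical) takes the norm of a scaled, un-averaged gradient with every square and partial sum rounded either to float16 or to float32, then applies the deferred factor. Assuming the accumulation really happens in float32, as the dtype=torch.float32 argument suggests, the sum of squares itself only overflows once the gradient norm exceeds roughly sqrt(3.4e38), about 1.8e19.

```python
import math
import struct


def round_to(x: float, fmt: str) -> float:
    """Round x to float16 (fmt="e") or float32 (fmt="f"); +/-inf if out of range."""
    try:
        return struct.unpack("<" + fmt, struct.pack("<" + fmt, x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)


def l2_norm(grads, fmt: str) -> float:
    """L2 norm with each square and partial sum stored in format fmt (toy)."""
    acc = 0.0
    for g in grads:
        acc = round_to(acc + round_to(g * g, fmt), fmt)
    return math.sqrt(acc)


# Hypothetical example: scaled, un-averaged gradient entries of 2560
# (they fit in float16, whose maximum is 65504, but their squares do not).
scaled_grads = [2560.0] * 4
factor = 1.0 / (128 * 2000)  # 1/loss_scale * num/sample_size, applied afterwards

fp16_norm = factor * l2_norm(scaled_grads, "e")
fp32_norm = factor * l2_norm(scaled_grads, "f")
print(f"float16 accumulation: {fp16_norm}")
print(f"float32 accumulation: {fp32_norm}")
```

With float16 accumulation the deferred norm becomes inf, while with float32 accumulation the same gradients give a finite norm of about 0.02. Under this toy model, the ordering only matters if some step of the computation is carried out in a narrower format than float32.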

funtion • Oct 25 '22 12:10