cudnn.torch

Loss is NaN when using half precision

Open mbcel opened this issue 8 years ago • 3 comments

When I run my model in half precision (fp16), the loss function returns NaN. Everything works fine in single precision (fp32), so I don't think the learning parameters are the problem. The loss is NaN right from the beginning of training.

I am using the SpatialCrossEntropyCriterion, and I explicitly exclude the MaxPooling and BatchNormalization modules from the cudnn conversion, since they don't work otherwise.

Relevant code:

criterion = cudnn.SpatialCrossEntropyCriterion(classWeights):cudaHalf()

model = createNet()
model = model:cudaHalf()

-- exclude pooling layers from the cudnn conversion due to compatibility problems with unpooling
cudnn.convert(model, cudnn, function(module)
  return torch.type(module):find("SpatialMaxPooling") ~= nil -- compatibility problems
      or torch.type(module):find("SpatialBatchNormalization") ~= nil -- apparently no cudaHalf implementation
end)

...
-- during training this returns NaN right from the beginning, or sometimes at the second iteration
loss = criterion:forward(outputGpu, labels)

Could the cause be the (missing?) CudaHalf implementation of the BatchNormalization module?

mbcel, Jul 18 '17 12:07

Okay, I figured out that the NaNs were due to the Adam optimisation. The default epsilon of 1e-8 is too small for fp16 and gets rounded to zero, as pointed out here. Setting it to 1e-4 fixes the NaN problem, but now the optimisation no longer decreases the loss. Is there a way to solve this while keeping the same learning rate?
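
To illustrate the underflow described above, here is a minimal sketch in plain Python that emulates fp16 rounding with the standard `struct` module. The helper `to_fp16` and the gradient magnitude are assumptions chosen for illustration; this is not the Torch `optim.adam` code, only the arithmetic of the denominator `sqrt(v) + epsilon` when every value is stored in half precision.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


# The smallest positive fp16 value is 2**-24 (about 6e-8), so 1e-8 rounds to zero.
eps_default = to_fp16(1e-8)
eps_raised = to_fp16(1e-4)

# Hypothetical second-moment estimate for a small gradient g = 1e-4:
# g * g = 1e-8 underflows to zero in fp16 as well.
v = to_fp16(1e-4 * 1e-4)

# Denominator of the Adam update, sqrt(v) + epsilon, stored in fp16.
denom_default = to_fp16(math.sqrt(v) + eps_default)
denom_raised = to_fp16(math.sqrt(v) + eps_raised)

print(eps_default)         # 0.0
print(denom_default)       # 0.0, so the update divides by zero
print(denom_raised > 0.0)  # True
```

With a zero denominator the update becomes a division by zero, which is consistent with the NaN appearing in the first iterations. A larger epsilon avoids that, but it also dominates the denominator whenever `sqrt(v)` is small, which shrinks the effective step size; that may be one reason the loss stops decreasing, although the thread does not confirm it.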

mbcel, Jul 18 '17 20:07

You can keep FP32 for the optimizer, as explained here: https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/ And here is a PyTorch snippet: https://gist.github.com/ajbrock/075c0ca4036dc4d8581990a6e76e07a3
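
As a rough sketch of why a higher-precision copy of the weights helps, the plain-Python example below emulates fp16 rounding with the `struct` module and applies the same small update to an fp16 weight and to a master copy. The helper `to_fp16`, the update size, and the step count are assumptions for illustration, and a Python float stands in for FP32 here; it is not the code from the linked gist.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


update = 1e-4  # hypothetical learning rate times gradient
steps = 100

w_fp16 = to_fp16(1.0)  # weight stored and updated in fp16
w_master = 1.0         # master copy kept in higher precision (stand-in for FP32)

for _ in range(steps):
    # Near 1.0 the fp16 spacing is about 5e-4, so a 1e-4 step is rounded away.
    w_fp16 = to_fp16(w_fp16 - update)
    # The master copy accumulates every small step.
    w_master = w_master - update

# Only the copy used for the forward and backward pass is cast to fp16.
w_for_forward = to_fp16(w_master)

print(w_fp16)               # 1.0, the weight never moved
print(w_for_forward < 1.0)  # True, the accumulated updates are visible
```

The pattern is to run the forward and backward pass in fp16, apply the optimizer step to the master copy, and then cast the master weights back to fp16 for the next iteration.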

Manuscrit, Feb 13 '19 20:02

I solved this issue by using autocast instead of .half(), as suggested by the PyTorch team:
https://discuss.pytorch.org/t/working-with-half-model-and-half-input/88494
https://pytorch.org/docs/master/amp.html

LiJiaqi96, Dec 28 '21 03:12