
Got a nan loss and gradient norm when training swin-l on imagenet22k with O1

Open jiandan42 opened this issue 4 years ago • 5 comments

When I use amp-opt-level O1 to train swin-large_patch4_window7_224 on ImageNet-22K, the loss and grad_norm become NaN starting from epoch [1/60] iter [880/3466]. Training is normal before that point, and both values suddenly turn NaN at that iteration. I follow the configuration in the paper: a total batch size of 4096 on 64 V100 GPUs (batch size 64 per GPU), 60 total epochs, and 5 warmup epochs. The problem does not happen when I switch to amp-opt-level O0, so it seems to be related to mixed-precision training. Is there any suggestion on why this happens and how to solve it, so that I can use mixed-precision training to accelerate the training process?

jiandan42 avatar Jun 22 '21 02:06 jiandan42
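
(For context, apex AMP at opt level O1 is typically wired up roughly as in the sketch below; this is a minimal toy example for illustration, not the actual Swin training loop.)

```python
# Minimal sketch of apex AMP "O1" training (toy model/optimizer for illustration only).
import torch
import torch.nn as nn
from apex import amp  # NVIDIA apex

model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# O1: patch whitelisted ops to fp16, keep fp32 master weights, use a dynamic loss scaler
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(64, 128, device="cuda")
target = torch.randint(0, 10, (64,), device="cuda")

loss = nn.functional.cross_entropy(model(x), target)
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```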

Hi, did you solve the issue? I also encountered the situation where grad_norm.avg is NaN, but the other items (loss and grad.val) are OK.

AI4Math-ShanZhang avatar Oct 28 '21 04:10 AI4Math-ShanZhang

Hi, did you solve the issue? I also encountered the situation where grad_norm.avg is NaN, but the other items (loss and grad.val) are OK.

It does not affect training. We will skip the training step if there is a nan.

ancientmooner avatar Dec 20 '21 09:12 ancientmooner

@jiandan42 Yes, it happens sometimes. You can try setting grad_clip, using the native PyTorch fp16 support, or using DeepSpeed. We find the latter two mixed-precision training frameworks more stable than apex.

ancientmooner avatar Dec 20 '21 09:12 ancientmooner
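
A minimal sketch of the suggested alternative, native PyTorch AMP (`torch.cuda.amp`) combined with gradient clipping; the model and optimizer here are toy placeholders, not the repository's training code:

```python
# Native PyTorch AMP with gradient clipping (toy example).
import torch
import torch.nn as nn

model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(64, 128, device="cuda")
target = torch.randint(0, 10, (64,), device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = nn.functional.cross_entropy(model(x), target)
scaler.scale(loss).backward()

# Unscale before clipping so the threshold applies to the true gradients.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

# GradScaler skips optimizer.step() automatically if any gradient is inf/NaN.
scaler.step(optimizer)
scaler.update()
```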

Hi, did you solve the issue? I also encountered the situation where grad_norm.avg is NaN, but the other items (loss and grad.val) are OK.

It does not affect training. We will skip the training step if there is a nan.

Hi @ancientmooner. Could you please tell me how to skip the training step if there is a NaN? I couldn't find this in the training code. Thanks a lot!

netw0rkf10w avatar Oct 06 '22 15:10 netw0rkf10w
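
For reference, one common pattern for skipping an update when the gradient norm is non-finite looks roughly like the sketch below. This is an assumed example using native PyTorch AMP, not code taken from this repository; note that `GradScaler.step` already skips the update on its own when it detects inf/NaN gradients.

```python
# Hypothetical helper: apply the optimizer step only when gradients are finite.
import torch

def clipped_step(model, optimizer, scaler, max_norm=5.0):
    scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if torch.isfinite(grad_norm):
        scaler.step(optimizer)  # apply the update only for finite gradients
    # else: the step is skipped; the scaler will typically lower the loss scale
    scaler.update()
    optimizer.zero_grad()
    return grad_norm
```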

Use `torch.bfloat16` instead of `torch.float16` in AMP.

rajeevgl01 avatar Nov 22 '23 19:11 rajeevgl01
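
A short sketch of what that looks like with native PyTorch AMP (toy example; bf16 autocast requires hardware support, e.g. Ampere-class GPUs):

```python
# Autocast to bfloat16 instead of float16 (toy example).
import torch
import torch.nn as nn

model = nn.Linear(128, 10).cuda()
x = torch.randn(64, 128, device="cuda")

# bf16 has the same exponent range as fp32, so a GradScaler is usually unnecessary.
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    out = model(x)
print(out.dtype)  # torch.bfloat16
```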