Synchronized-BatchNorm-PyTorch

Training stuck

YinengXiong opened this issue · 4 comments

Hi ~ I use `if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model)`, followed by `model = bnconvert(model)` and `model.cuda()`, to enable sync-bn during multi-GPU training (see the snippet below). However, the training procedure seems to get stuck at the final batch of an epoch.
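For readability, here is a minimal sketch of the setup described above, assuming `bnconvert` corresponds to the repo's `convert_model` helper and using a small stand-in network:

```python
import torch
import torch.nn as nn
from sync_batchnorm import convert_model  # assumed to be what `bnconvert` refers to

# stand-in network containing BatchNorm layers
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # plain DataParallel, no replication callback
model = convert_model(model)        # replace nn.BatchNorm* with SynchronizedBatchNorm*
model.cuda()
```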

[screenshot attached showing where training hangs]

YinengXiong commented on Mar 29, 2021

The implementation requires that the modules on different devices invoke the batchnorm exactly the SAME number of times in each forward pass. For example, you cannot call the batchnorm on GPU 0 but not on GPU 1. The #i (i = 1, 2, 3, ...) call of the batchnorm on each device is treated as a single group, and the statistics of that group are reduced across devices. This is tricky, but it is a good way to handle PyTorch's dynamic computation graph. Although it sounds complicated, this is usually not an issue for most models.

Can you check this?
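To make the failure mode concrete, here is a hypothetical pattern (not taken from this issue) that would violate the requirement, because different replicas may call the batchnorm a different number of times:

```python
import torch
import torch.nn as nn

class CondBNNet(nn.Module):
    """Hypothetical model whose forward pass calls BatchNorm a data-dependent
    number of times. Under synchronized BN, one GPU replica may take the branch
    while another does not, so the cross-device statistics reduction never completes."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.conv(x)
        if x.mean() > 0:    # data-dependent branch
            x = self.bn(x)  # BN invoked on some devices only -> hang
        return x
```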

vacancy commented on Mar 29, 2021

So if I want to use a model from e.g. torchvision that is built with nn.BatchNorm, I should do the following (see the sketch after this list):

  0. model = torchvision.models.resnet50()
  1. model = bnconvert(model)
  2. model = DataParallelWithCallback(model)
  3. model.cuda()

Am I right?
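In code, those steps would look roughly like the sketch below, assuming `bnconvert` corresponds to the repo's `convert_model` helper and that `DataParallelWithCallback` is imported from `sync_batchnorm`:

```python
import torch
import torchvision
from sync_batchnorm import convert_model, DataParallelWithCallback  # assumed import path

# 0. a standard torchvision model (uses nn.BatchNorm2d internally)
model = torchvision.models.resnet50()

# 1. replace every nn.BatchNorm* layer with its synchronized counterpart
model = convert_model(model)

# 2. wrap with the callback-aware data-parallel wrapper, not plain nn.DataParallel
model = DataParallelWithCallback(model, device_ids=list(range(torch.cuda.device_count())))

# 3. move the model to the GPUs
model.cuda()
```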

YinengXiong commented on Apr 1, 2021

Correct. As for the original hang, I suspect the reason is what I described above: each device must invoke the batchnorm exactly the same number of times in every forward pass, otherwise the statistics reduction will block.

vacancy commented on Apr 7, 2021


Thanks a lot!

YinengXiong commented on Apr 8, 2021