Synchronized-BatchNorm-PyTorch
Training stuck
Hi ~
I use
```python
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
    model = bnconvert(model)
    model.cuda()
```
to enable sync-bn during multi-GPU training, but the training procedure gets stuck at the final batch of an epoch.
The implementation requires that each module on the different devices invoke the batchnorm exactly the SAME number of times in each forward pass. For example, you cannot call the batchnorm on GPU0 but not on GPU1. The i-th (i = 1, 2, 3, ...) call of the batchnorm on each device is treated as one group, and their statistics are reduced together. This is tricky, but it is a good way to handle PyTorch's dynamic computation graph. Although it sounds complicated, this is usually not an issue for most models.
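For instance (a hypothetical sketch, not code from this repo): if a forward pass has data-dependent control flow around a batchnorm, or the last, smaller batch of an epoch leaves one replica with nothing to do, the replicas can disagree on how many batchnorm calls happen, and the reduction then waits forever; that may be why the hang shows up at the final batch.

```python
import torch
import torch.nn as nn

class UnevenBNNet(nn.Module):
    """Toy module whose forward pass calls BatchNorm a data-dependent number of times."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)  # imagine this converted to SynchronizedBatchNorm2d

    def forward(self, x):
        x = self.conv(x)
        if x.shape[-1] > 32:   # data-dependent branch: replicas may take different paths,
            x = self.bn(x)     # giving mismatched BN call counts across GPUs and a deadlock
        return torch.relu(x)
```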
Can you check this?
So if I want to use a model from e.g. torchvision, which is built with nn.BatchNorm, I should (see the sketch below):

1. model = torchvision.models.resnet50()
2. model = bnconvert(model)
3. model = DataParallelWithCallback(model)
4. model.cuda()
Am I right?
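A minimal sketch of those steps, assuming the conversion helper and the data-parallel wrapper are importable from the repo's sync_batchnorm package as convert_model and DataParallelWithCallback (the bnconvert name used in this thread presumably refers to the same conversion helper); adjust the imports to whatever your checkout exposes.

```python
import torch
import torchvision
from sync_batchnorm import convert_model, DataParallelWithCallback

model = torchvision.models.resnet50()        # 1. plain torchvision model built with nn.BatchNorm layers
model = convert_model(model)                 # 2. swap nn.BatchNorm* for the synchronized version
if torch.cuda.device_count() > 1:
    model = DataParallelWithCallback(model)  # 3. DataParallel variant that runs the synchronization callback
model = model.cuda()                         # 4. move parameters and buffers to the GPUs
```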
Correct. I suspect the reason is the following:
The implementation requires that each module on the different devices invoke the batchnorm exactly the SAME number of times in each forward pass. For example, you cannot call the batchnorm on GPU0 but not on GPU1. The i-th (i = 1, 2, 3, ...) call of the batchnorm on each device is treated as one group, and their statistics are reduced together. This is tricky, but it is a good way to handle PyTorch's dynamic computation graph. Although it sounds complicated, this is usually not an issue for most models.
Thanks a lot!