swa-tutorials-pytorch
Batch Norm update doesn't require cpu
Hi,
First, thanks for your work. It's pretty cool! I just wanted to clarify something.
In your example code, in order to update the batch normalization statistics at the end of training, you wrote:
swa_model = swa_model.cpu()
torch.optim.swa_utils.update_bn(train_loader, swa_model)
swa_model = swa_model.cuda()
but the update_bn function accepts a device keyword argument, which lets it run on CUDA when available (see https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py#L124).
So actually a better way to write your code would be something like:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
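For reference, here is a minimal self-contained sketch of that pattern. The tiny model and loader are hypothetical stand-ins for the tutorial's real ones; the point is only that update_bn moves each batch to the given device itself, so no .cpu()/.cuda() round trip is needed:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy model containing a BatchNorm layer (stand-in for the real network).
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
swa_model = torch.optim.swa_utils.AveragedModel(model)

# Hypothetical stand-in for train_loader.
train_loader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=8)

# Pick the device once; update_bn transfers each batch to it internally.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
swa_model = swa_model.to(device)
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
```

After the call, the BatchNorm running statistics of swa_model reflect the averaged weights, with everything staying on the chosen device throughout.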