
Batch Norm update doesn't require CPU

ListIndexOutOfRange opened this issue 3 years ago · 0 comments

Hi,

First, thanks for your work. It's pretty cool! I just wanted to clarify something.

In your example code, in order to update the batch normalization statistics at the end of training, you wrote:

swa_model = swa_model.cpu()
torch.optim.swa_utils.update_bn(train_loader, swa_model)
swa_model = swa_model.cuda() 

but the update_bn function accepts a device keyword argument, so it can run the forward passes on CUDA when available (see https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py#L124).

So a better way to write this would be something like:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
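To make the suggestion concrete, here is a minimal, self-contained sketch. The tiny model, loader, and tensor shapes are made up purely for illustration; the point is that update_bn moves each batch to the given device itself, so no round trip through .cpu() is needed:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy model containing a BatchNorm layer (illustration only)
model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4))
swa_model = torch.optim.swa_utils.AveragedModel(model)

# Toy loader; update_bn unpacks (input, target) tuples and forwards the input
data = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
train_loader = DataLoader(data, batch_size=16)

# Recompute the BatchNorm running statistics on whichever device is
# available, instead of moving the model to CPU and back
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
swa_model = swa_model.to(device)
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
```

After the call, the BatchNorm layer's running statistics reflect one full pass over the loader (update_bn resets them, then accumulates with momentum=None), and the model never has to leave the GPU.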

ListIndexOutOfRange · Jan 06 '22 14:01