query2labels

Question about FrozenBatchNorm

Open • lbx73737373 opened this issue 2 years ago • 1 comment

Excellent work! I noticed that you use FrozenBatchNorm2d in your code. A comment there says the layer is adopted to prevent models other than ResNets from producing NaNs, yet you apply it by default to every network, ResNets included. What is the reason for that? Intuitively, ordinary (trainable) BatchNorm should be better for training.

```python
import torch
import torchvision


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.ops.misc with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """
    ...  # buffers and forward() omitted in this excerpt


class Backbone(BackboneBase):  # BackboneBase is defined elsewhere in the repo (not shown)
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool,
                 pretrained: bool = True):
        if name in ['resnet18', 'resnet50', 'resnet34', 'resnet101']:
            # Every BatchNorm layer in the torchvision ResNet is built as FrozenBatchNorm2d
            backbone = getattr(torchvision.models, name)(
                replace_stride_with_dilation=[False, False, dilation],
                pretrained=True,
                norm_layer=FrozenBatchNorm2d)
        # ... rest of __init__ omitted in this excerpt
```
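For context on what the excerpt elides: a frozen BatchNorm layer applies the same per-channel affine transform as `BatchNorm2d` in eval mode, but all four tensors are buffers, so it has no trainable parameters and ignores `train()`/`eval()`. The following is a minimal sketch modeled on the torchvision-style pattern, not the query2labels code itself; `frozen_from_bn` is a hypothetical helper added only to show the equivalence with an eval-mode `BatchNorm2d`.

```python
import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d with fixed statistics and affine parameters (sketch).

    Everything is registered as a buffer, so .parameters() is empty, the
    optimizer never updates these values, and train()/eval() change nothing.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fold the four buffers into one per-channel scale and shift.
        # eps is added before rsqrt so a zero running_var cannot give inf/NaN.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        shift = self.bias - self.running_mean * scale
        return x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)


def frozen_from_bn(bn: nn.BatchNorm2d) -> FrozenBatchNorm2d:
    """Hypothetical helper: copy a trained BatchNorm2d into a frozen layer."""
    frozen = FrozenBatchNorm2d(bn.num_features, eps=bn.eps)
    frozen.weight.copy_(bn.weight.detach())
    frozen.bias.copy_(bn.bias.detach())
    frozen.running_mean.copy_(bn.running_mean)
    frozen.running_var.copy_(bn.running_var)
    return frozen


torch.manual_seed(0)
bn = nn.BatchNorm2d(8)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)   # non-trivial affine parameters
    bn.bias.uniform_(-1.0, 1.0)
    for _ in range(5):             # train mode: populates running_mean / running_var
        bn(torch.randn(16, 8, 4, 4) * 3 + 1)
bn.eval()

frozen = frozen_from_bn(bn)
x = torch.randn(2, 8, 4, 4)
print(torch.allclose(frozen(x), bn(x), atol=1e-5))  # True
print(len(list(frozen.parameters())))               # 0
```

The key difference from a trainable layer shows up during training: a regular `BatchNorm2d` in train mode normalizes with the current batch's statistics and updates its running averages, while the frozen layer keeps using the pretrained statistics and contributes nothing for the optimizer to learn.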

Any information would be appreciated!

lbx73737373 • Mar 18 '23 15:03

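This trick is mainly used in object detection, where the batch size per GPU is often very small, so batch statistics estimated during training are noisy and freezing the pretrained statistics is more stable. The batch size in this method is relatively large, however, so I suggest removing FrozenBatchNorm2d.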
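The batch-size argument can be illustrated numerically: the spread of a batch mean shrinks roughly as 1/sqrt(batch size), so tiny batches give much noisier normalization statistics. The stdlib sketch below assumes, for simplicity, that each image contributes one independent unit-variance sample per channel; real conv feature maps pool over N×H×W values, but pixels within an image are correlated, so this only shows the scaling, not exact magnitudes. The function name `batch_mean_spread` is hypothetical. To act on the suggestion in this codebase, one would drop `norm_layer=FrozenBatchNorm2d` from the torchvision constructor call, since torchvision's ResNet falls back to `nn.BatchNorm2d` when `norm_layer` is not given.

```python
import random
import statistics


def batch_mean_spread(batch_size: int, trials: int = 2000, seed: int = 0) -> float:
    """Std. dev. of the per-batch mean of a unit-variance activation.

    Simplifying assumption: each image contributes one independent sample
    per channel, so the result should be close to 1 / sqrt(batch_size).
    """
    rng = random.Random(seed)
    means = [
        statistics.fmean(rng.gauss(0.0, 1.0) for _ in range(batch_size))
        for _ in range(trials)
    ]
    return statistics.pstdev(means)


for n in (2, 8, 64):
    # Compare the simulated spread with the 1/sqrt(n) prediction.
    print(n, round(batch_mean_spread(n), 3), round(n ** -0.5, 3))
```

With two images per batch the estimated mean wanders by roughly 0.7 standard deviations from batch to batch, versus about 0.125 with 64, which is why frozen statistics help in the small-batch regime and matter much less at larger batch sizes.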

sorrowyn • Apr 01 '23 13:04