Why do you use `BatchRenorm` instead of `nn.BatchNorm3d`?
I found that you implemented a BatchRenorm module in your code. I was wondering why you didn't use PyTorch's nn.BatchNorm3d directly? I hope you can explain this detail, thanks a lot!
BatchRenorm is an implementation of the batch renormalization paper. I experimented with different normalization layers: BatchRenorm, nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, and nn.LayerNorm. BatchRenorm performed best.
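For reference, here is a minimal sketch of what batch renormalization does for 5D inputs of shape (N, C, D, H, W), following Ioffe (2017). The class name `BatchRenorm3d` and the hyperparameters `r_max`, `d_max`, and `momentum` are my own choices for illustration and may not match the repository's implementation.

```python
import torch
import torch.nn as nn


class BatchRenorm3d(nn.Module):
    """Sketch of batch renormalization for (N, C, D, H, W) inputs."""

    def __init__(self, num_features, eps=1e-5, momentum=0.01, r_max=3.0, d_max=5.0):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.r_max = r_max  # clip bound for the variance correction r
        self.d_max = d_max  # clip bound for the mean correction d
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_std", torch.ones(num_features))

    def forward(self, x):
        view = (1, -1, 1, 1, 1)  # broadcast per-channel tensors over (N, C, D, H, W)
        if self.training:
            dims = (0, 2, 3, 4)
            batch_mean = x.mean(dim=dims)
            batch_std = x.std(dim=dims, unbiased=False) + self.eps
            # r and d correct the batch statistics toward the running statistics;
            # both are treated as constants, so no gradient flows through them.
            r = (batch_std / self.running_std).detach().clamp(1.0 / self.r_max, self.r_max)
            d = ((batch_mean - self.running_mean) / self.running_std).detach().clamp(-self.d_max, self.d_max)
            x_hat = (x - batch_mean.view(view)) / batch_std.view(view)
            x_hat = x_hat * r.view(view) + d.view(view)
            # update running statistics with an exponential moving average
            with torch.no_grad():
                self.running_mean += self.momentum * (batch_mean - self.running_mean)
                self.running_std += self.momentum * (batch_std - self.running_std)
        else:
            # inference: use running statistics, exactly like plain batch norm
            x_hat = (x - self.running_mean.view(view)) / self.running_std.view(view)
        return self.weight.view(view) * x_hat + self.bias.view(view)
```

Note that the paper also ramps r_max and d_max up from 1 over the course of training; this sketch keeps them fixed for brevity. With r = 1 and d = 0 the layer reduces to ordinary batch norm, which is why it behaves much better than plain batch norm when the per-batch statistics are noisy (e.g. at batch size 4).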
nn.BatchNorm3d in particular led to unstable training: after a certain number of iterations, I observed the test-set mIoU periodically dropping to almost 0 and then recovering. This is probably caused by the combination of batch norm, a rather small batch size (4, due to memory pressure from the 3D grid), and the Adam optimizer.