WeightStandardization

Make the mean calculation consistent with the std calculation

Open · huangeddie opened this issue on Aug 24, 2020 · 2 comments

Hi, in the Conv2d with WS you compute the mean as `weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)`.

Why not compute it as `weight_mean = weight.view(weight.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1)` instead?

That would also be consistent with how you compute the standard deviation.
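A minimal sketch checking that the two formulations agree, assuming a random weight tensor of shape (out_channels, in_channels, kH, kW). The chained version averages equally sized groups at each step, so it ends up equal to the mean over all elements of each filter. `mean(dim=(1, 2, 3), keepdim=True)` is a third way to write the same reduction.

```python
import torch

# Hypothetical conv weight: (out_channels, in_channels, kH, kW).
torch.manual_seed(0)
weight = torch.randn(8, 4, 3, 3, dtype=torch.float64)

# Chained per-dimension means, as in the current Conv2d with WS.
mean_chained = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

# Flatten each filter, then reduce (same pattern as the std computation).
mean_flat = weight.view(weight.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1)

# Single reduction over several dims at once.
mean_tuple = weight.mean(dim=(1, 2, 3), keepdim=True)

print(mean_chained.shape, mean_flat.shape, mean_tuple.shape)
print(torch.allclose(mean_chained, mean_flat), torch.allclose(mean_tuple, mean_flat))
```

All three produce a (8, 1, 1, 1) tensor with the same values, so the choice is about readability rather than behavior.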

huangeddie · Aug 24, 2020, 21:08

I also don't think the `expand_as` in `weight = (weight - weight_mean) / std.expand_as(weight)` is necessary: `std` already has the same number of dimensions as `weight`, so broadcasting handles the division.
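A quick sketch of the broadcasting point. The random `weight` tensor and the `1e-5` epsilon are assumptions for illustration. `std` is built in the flatten-then-reshape style so that it has shape (C_out, 1, 1, 1), which broadcasts against the (C_out, C_in, kH, kW) weight without an explicit `expand_as`.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(8, 4, 3, 3)  # hypothetical (C_out, C_in, kH, kW)

weight_mean = weight.mean(dim=(1, 2, 3), keepdim=True)  # shape (8, 1, 1, 1)
# Assumed epsilon of 1e-5 to avoid division by zero.
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5

# Explicitly expanding std vs. relying on broadcasting.
with_expand = (weight - weight_mean) / std.expand_as(weight)
with_broadcast = (weight - weight_mean) / std

print(std.dim() == weight.dim())
print(torch.allclose(with_expand, with_broadcast))
```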

huangeddie · Aug 24, 2020, 21:08

You can take a look at

https://github.com/open-mmlab/mmcv/blob/d5cbf7eed1269095bfba1a07913efbbc99d2d10b/mmcv/cnn/bricks/conv_ws.py

which provides a nicer-looking implementation.
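As a rough illustration of how both suggestions fit together, here is a minimal sketch of a weight-standardized convolution that flattens each filter once and relies on broadcasting. It is not the code at the linked mmcv URL; the class name `WSConv2d` and `eps=1e-5` are assumptions, and because it calls `F.conv2d` directly it assumes the default zero `padding_mode`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2d(nn.Conv2d):
    """Hypothetical Conv2d with Weight Standardization: each output filter is
    normalized to zero mean and unit std before the convolution."""

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps  # assumed epsilon for numerical stability

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out_channels = self.weight.size(0)
        # Flatten each filter once and reuse it for both statistics.
        flat = self.weight.view(out_channels, -1)
        mean = flat.mean(dim=1).view(-1, 1, 1, 1)
        std = flat.std(dim=1).view(-1, 1, 1, 1)
        # (C_out, 1, 1, 1) broadcasts against the weight, so no expand_as is needed.
        weight = (self.weight - mean) / (std + self.eps)
        # Assumes padding_mode="zeros"; F.conv2d does not apply other padding modes.
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


conv = WSConv2d(3, 16, kernel_size=3, padding=1)
y = conv(torch.randn(2, 3, 32, 32))
print(y.shape)  # torch.Size([2, 16, 32, 32])
```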

joe-siyuan-qiao · Aug 25, 2020, 00:08