WeightStandardization
Make the mean calculation consistent with the std calculation
Hi, in the Conv2d with WS you compute the mean as
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
Why not do something like
weight_mean = weight.view(weight.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1)
This would also be consistent with how you compute the standard deviation.
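A quick check that both forms agree. This is a minimal sketch assuming PyTorch and an arbitrary random weight of shape (out_channels, in_channels, kH, kW); the sizes are made up. Because each filter is averaged over equally sized dimensions, the mean of the per-dimension means equals the mean over the flattened filter, so both give the same (out_channels, 1, 1, 1) tensor up to floating-point rounding. torch.mean also accepts a tuple of dims, which is a third equivalent spelling.

```python
import torch

torch.manual_seed(0)
# Arbitrary example conv weight: (out_channels, in_channels, kH, kW)
weight = torch.randn(8, 4, 3, 3)

# Chained per-dimension means (the current form)
mean_chained = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

# Flatten each filter, take one mean, reshape back (mirrors the std computation)
mean_flat = weight.view(weight.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1)

# Single reduction over several dims at once
mean_tuple = weight.mean(dim=(1, 2, 3), keepdim=True)

print(mean_chained.shape, mean_flat.shape, mean_tuple.shape)
print(torch.allclose(mean_chained, mean_flat, atol=1e-6))
```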
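I also don't think the expand_as in
weight = (weight - weight_mean) / std.expand_as(weight)
is necessary: std already has the same number of dimensions as weight (one value per output filter, with singleton trailing dimensions), so ordinary broadcasting handles the division.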
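To illustrate the broadcasting point, here is a small sketch assuming PyTorch, an arbitrary random weight, and an assumed eps of 1e-5 added to std. Dividing by the (out_channels, 1, 1, 1) std directly gives the same result as dividing by std.expand_as(weight).

```python
import torch

torch.manual_seed(0)
weight = torch.randn(8, 4, 3, 3)  # arbitrary example weight
flat = weight.view(weight.shape[0], -1)

weight_mean = flat.mean(dim=1).view(-1, 1, 1, 1)
std = flat.std(dim=1).view(-1, 1, 1, 1) + 1e-5  # eps value is an assumption

with_expand = (weight - weight_mean) / std.expand_as(weight)
# (8, 1, 1, 1) broadcasts against (8, 4, 3, 3) without an explicit expand
with_broadcast = (weight - weight_mean) / std

print(torch.allclose(with_expand, with_broadcast))
```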
You can take a look at mmcv's conv_ws.py, which provides a cleaner implementation:
https://github.com/open-mmlab/mmcv/blob/d5cbf7eed1269095bfba1a07913efbbc99d2d10b/mmcv/cnn/bricks/conv_ws.py
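Putting both suggestions together, a Conv2d with WS could look roughly like the sketch below. It is not a copy of the mmcv code or of this repo's code. WSConv2d is a hypothetical name, eps=1e-5 is an assumed default, and it calls F.conv2d directly, so any non-default padding_mode is ignored. Mean and std both use the same flattened per-filter reduction, and no expand_as is needed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization (hypothetical minimal sketch)."""

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps  # assumed small constant to avoid division by zero

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out_channels = self.weight.shape[0]
        # Flatten each output filter so mean and std use the same reduction
        flat = self.weight.view(out_channels, -1)
        mean = flat.mean(dim=1).view(-1, 1, 1, 1)
        std = flat.std(dim=1).view(-1, 1, 1, 1)
        # (out_channels, 1, 1, 1) broadcasts against the 4-D weight
        weight = (self.weight - mean) / (std + self.eps)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


conv = WSConv2d(3, 16, kernel_size=3, padding=1)
y = conv(torch.randn(2, 3, 32, 32))
print(y.shape)  # torch.Size([2, 16, 32, 32])
```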