Standardizing normalization layers
LayerNorm is understood as normalizing the activations by reducing across all non-batch axes. Currently, Flax's implementation of LayerNorm defaults to reduction_axes=-1. This works for 2D inputs, but for higher dimensional tensors, this would only reduce the trailing dimension. Should we change the default implementation so that it normalizes all non-batch axes by default (assuming the leading dimension is the batch axis)? This applies to RMSNorm as well.
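To make the difference concrete, here is a minimal sketch, assuming a Flax version where flax.linen.LayerNorm exposes reduction_axes (the input shape and axis choices below are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 8, 16, 32))  # e.g. (batch, height, width, features)

# Current default: statistics are computed over the trailing axis only.
default_ln = nn.LayerNorm()  # reduction_axes=-1
vars_default = default_ln.init(jax.random.PRNGKey(0), x)
y_default = default_ln.apply(vars_default, x)

# Proposed default: statistics are computed over all non-batch axes,
# matching the usual definition of LayerNorm for N-D inputs.
full_ln = nn.LayerNorm(reduction_axes=(1, 2, 3))
vars_full = full_ln.init(jax.random.PRNGKey(0), x)
y_full = full_ln.apply(vars_full, x)
```

With the current default, mean and variance are computed independently for each length-32 vector; with reduction_axes=(1, 2, 3), they are computed once per example over all 8 * 16 * 32 activations.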
Another thing: currently, all normalization layers with learnable scale and bias have a feature_axes (or equivalent) input arg so that the user can specify the shape of the learnable params, except GroupNorm (which always defines feature_axes=-1). Should we add this to GroupNorm as well?
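For reference, a small sketch of the current asymmetry, assuming a Flax version where flax.linen.LayerNorm takes feature_axes while flax.linen.GroupNorm does not (the shapes and num_groups=8 are arbitrary choices for illustration):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 16, 32))  # (batch, length, features)

# LayerNorm lets the caller choose which axes the learnable scale/bias
# parameters live on; here they get shape (16, 32).
ln = nn.LayerNorm(feature_axes=(-2, -1))
ln_vars = ln.init(jax.random.PRNGKey(0), x)
print(ln_vars['params']['scale'].shape)  # (16, 32)

# GroupNorm has no such argument: its scale/bias are always defined over
# the last axis, so they get shape (32,) regardless of the input layout.
gn = nn.GroupNorm(num_groups=8)
gn_vars = gn.init(jax.random.PRNGKey(0), x)
print(gn_vars['params']['scale'].shape)  # (32,)
```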
This works for 2D inputs, but for higher dimensional tensors, this would only reduce the trailing dimension.
Interesting. Why do we specialize 2D LayerNorm? Also, where do we specialize it?
Should we add this to GroupNorm as well?
Yeah, sounds like a good idea! Thanks for looking into this.
Interesting. Why do we specialize 2D LayerNorm? Also, where do we specialize it?
After internal discussion, we have decided to keep the default reduction_axes=-1.