
Standardizing normalization layers

chiamp opened this issue 1 year ago • 2 comments


LayerNorm is generally understood as normalizing the activations by reducing across all non-batch axes. In Flax's current implementation of LayerNorm, the default is reduction_axes=-1. This works for 2D inputs, but for higher-dimensional tensors it only reduces over the trailing dimension. Should we change the default so that it normalizes over all non-batch axes (assuming the leading dimension is the batch axis)? The same applies to RMSNorm.
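
For concreteness, here is a minimal sketch of the difference (the input shape, RNG keys, and axis choices are purely illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative 3D input: (batch, sequence, features).
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 16))

# Current default: statistics are computed over the last axis only.
y_last, _ = nn.LayerNorm().init_with_output(jax.random.PRNGKey(1), x)

# Normalizing over all non-batch axes requires passing them explicitly.
y_all, _ = nn.LayerNorm(reduction_axes=(-2, -1)).init_with_output(
    jax.random.PRNGKey(1), x
)

# With the default, every (batch, seq) row is individually zero-mean...
print(jnp.abs(y_last.mean(axis=-1)).max())  # ~0
# ...whereas when all non-batch axes are reduced jointly, individual rows need not be.
print(jnp.abs(y_all.mean(axis=-1)).max())   # generally not ~0
```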

Another thing: currently, all normalization layers with learnable scale and bias have a feature_axes (or equivalent) input arg so that the user can specify the shape of the learnable params, except GroupNorm (which effectively hard-codes feature_axes=-1). Should we add this to GroupNorm as well?
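
As a point of reference, a minimal sketch (shapes and group count are illustrative) of how feature_axes already shapes the learnable params on LayerNorm, versus GroupNorm's fixed behavior:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 16))  # (batch, seq, features)

# LayerNorm: feature_axes controls the shape of the learnable scale/bias,
# e.g. placing them over the last two axes yields params of shape (8, 16).
ln = nn.LayerNorm(reduction_axes=(-2, -1), feature_axes=(-2, -1))
ln_vars = ln.init(jax.random.PRNGKey(1), x)
print(jax.tree_util.tree_map(jnp.shape, ln_vars))
# e.g. {'params': {'bias': (8, 16), 'scale': (8, 16)}}

# GroupNorm exposes no such argument; its scale/bias always live on the
# trailing (channel) axis, here of shape (16,).
gn = nn.GroupNorm(num_groups=4)
gn_vars = gn.init(jax.random.PRNGKey(1), x)
print(jax.tree_util.tree_map(jnp.shape, gn_vars))
# e.g. {'params': {'bias': (16,), 'scale': (16,)}}
```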

chiamp commented on Jan 27 '24

This works for 2D inputs, but for higher-dimensional tensors it only reduces over the trailing dimension.

Interesting. Why do we specialize 2D LayerNorm? Also, where do we specialize it?

Should we add this to GroupNorm as well?

Yeah, sounds like a good idea! Thanks for looking into this.

cgarciae commented on Jan 29 '24

Interesting. Why do we specialize 2D LayerNorm? Also, where do we specialize it?

After internal discussion, we have decided to keep the default reduction_axes=-1.

chiamp commented on Jan 30 '24