keras-contrib
Unnecessary ValueError in instancenormalization.py
I don't see the reasoning behind these lines: https://github.com/keras-team/keras-contrib/blob/5ffab172661411218e517a50170bb97760ea567b/keras_contrib/layers/normalization/instancenormalization.py#L80-L81
Say I want to apply an InstanceNormalization layer to tensors of different ranks (e.g. outputs of Dense and Conv2D layers), and I always want the normalization to be applied to the feature axis. With `data_format == 'channels_last'`, I think it would make sense to simply pass `axis=-1` to all my InstanceNormalization layers.
However, because of the two lines above, this raises a ValueError for rank 1 tensors, i.e. inputs of shape `(batch, features)` such as Dense outputs.
After deleting those lines, my setup with `axis=-1` for all layers now works.
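To make "the feature axis for every rank" concrete, here is a rough sketch of which axes would be averaged over when `axis=-1` is passed to inputs of different ranks. It assumes the layer normalizes over every axis except the batch axis (0) and `axis`; that is my reading of the computation, not code taken from keras-contrib, and `reduction_axes` is a hypothetical helper.

```python
def reduction_axes(ndim, axis):
    """Axes averaged over per sample, assuming every axis except the
    batch axis (0) and the normalization axis `axis` is reduced.

    Hypothetical helper for illustration; not the keras-contrib source.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for ndim {ndim}")
    kept = {0, axis % ndim}  # batch axis and feature axis are kept
    return [a for a in range(ndim) if a not in kept]


# channels_last shapes (including the batch axis):
#   Conv2D -> (batch, h, w, c), Conv1D -> (batch, steps, c), Dense -> (batch, units)
for name, ndim in [("Conv2D", 4), ("Conv1D", 3), ("Dense", 2)]:
    print(name, reduction_axes(ndim, -1))
```

Under that assumption, `axis=-1` picks out the channel axis for every rank, but for a Dense output nothing is left to reduce once the batch and feature axes are both excluded.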
Shouldn't `axis=1` and `axis=-1` be perfectly valid values when passing rank 1 tensors?
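One way to check what `axis=1` / `axis=-1` would compute on a rank 1 input is a small NumPy sketch of instance normalization. It assumes the layer subtracts the mean and divides by (std + epsilon) over every axis except the batch axis and `axis`; this is an assumption about the computation, not the keras-contrib source, `instance_norm` is a hypothetical helper, and the learned gamma/beta are left out.

```python
import numpy as np


def instance_norm(x, axis, epsilon=1e-3):
    """Normalize `x` over every axis except the batch axis (0) and `axis`.

    Sketch only: assumes (x - mean) / (std + epsilon) over those axes,
    without the learned gamma/beta of the real layer.
    """
    ndim = x.ndim
    kept = {0, axis % ndim}  # batch axis and feature axis are not reduced
    axes = tuple(a for a in range(ndim) if a not in kept)
    mean = x.mean(axis=axes, keepdims=True)
    std = x.std(axis=axes, keepdims=True) + epsilon
    return (x - mean) / std


# Rank 1 per-sample input, shape (batch, units), e.g. a Dense output.
dense_out = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
# `axes` is empty here, so mean == x, std == 0 and each entry is (x - x) / epsilon.
print(instance_norm(dense_out, axis=-1))

# Rank 3 per-sample input, shape (batch, h, w, c), e.g. a Conv2D output.
conv_out = np.arange(2 * 4 * 4 * 3, dtype=float).reshape(2, 4, 4, 3)
normed = instance_norm(conv_out, axis=-1)
# Each (sample, channel) slice is normalized over h and w.
print(normed.mean(axis=(1, 2)))
```

Under this assumption, a `(batch, units)` input with `axis=1` or `axis=-1` has no axis left to average over, so every entry normalizes to zero, whereas Conv2D outputs are normalized per sample and channel as intended.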