keras
BatchNormalization gives incorrect output with masked inputs of more than 3 dimensions
With masked inputs of more than 3 dimensions, the masked mean/variance calculations are incorrect, so the inputs are not normalized correctly. For example:
import keras
x = keras.ops.ones((1, 2, 3, 4))
x._keras_mask = keras.ops.ones((1, 2, 1))
y = keras.layers.BatchNormalization()(x, training=True)
print(keras.ops.mean(y, axis=(0, 1, 2)))  # per-channel mean, shape (4,)
gives the output (TensorFlow backend)
tf.Tensor([-0.57732624 -0.57732624 -0.57732624 -0.57732624], shape=(4,), dtype=float32)
instead of the correct normalized output ([0, 0, 0, 0]).
The basic issue is that this calculation is incorrect: https://github.com/keras-team/keras/blob/efaaf85e19113400f23462cbafcef433cd95ad9c/keras/src/layers/normalization/batch_normalization.py#L310-L314, because it doesn't account for the broadcasting of the mask. It gives a count of 2 in the above example, when it should be 6: once the mask is broadcast to the input shape (1, 2, 3, 4), it covers 1 * 2 * 3 = 6 positions along the reduced axes for each channel.
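To make the arithmetic concrete, here is a minimal pure-Python sketch (not the Keras implementation) of per-channel masked moments for the example above. It assumes the mask gets a trailing axis, (1, 2, 1, 1), before broadcasting against the inputs, uses BatchNormalization's default epsilon of 1e-3, and omits the small backend epsilon that Keras adds to the count. The names `masked_normalize`, `mask_at`, and the `fixed` flag are hypothetical. Dividing by the un-broadcast mask sum (2) reproduces the reported -0.5773, while counting the broadcast mask (6) gives 0.

```python
import itertools
import math

# Assumed setup mirroring the report: inputs of shape (1, 2, 3, 4) and a
# mask of shape (1, 2, 1). Here the mask gets a trailing axis, (1, 2, 1, 1),
# before it broadcasts against the inputs.
IN_SHAPE = (1, 2, 3, 4)
MASK_SHAPE = (1, 2, 1, 1)
EPS = 1e-3  # BatchNormalization's default epsilon


def ones(shape):
    """A dense all-ones 'tensor' stored as {index tuple: value}."""
    return {idx: 1.0 for idx in itertools.product(*map(range, shape))}


def mask_at(mask, idx):
    """Read the mask at a full input index; size-1 axes broadcast (index 0)."""
    return mask[tuple(i if n > 1 else 0 for i, n in zip(idx, MASK_SHAPE))]


def masked_normalize(x, mask, fixed):
    """Normalize each channel (last axis) with masked mean/variance."""
    out = {}
    for c in range(IN_SHAPE[-1]):
        # All positions on the reduction axes (0, 1, 2) for this channel.
        pos = [p + (c,) for p in itertools.product(*map(range, IN_SHAPE[:-1]))]
        total = sum(mask_at(mask, p) * x[p] for p in pos)
        if fixed:
            # Count the broadcast mask: one weight per reduced position (6).
            count = sum(mask_at(mask, p) for p in pos)
        else:
            # Reported behaviour: sum the un-broadcast mask instead (2).
            count = sum(mask.values())
        mean = total / count
        var = sum(mask_at(mask, p) * (x[p] - mean) ** 2 for p in pos) / count
        for p in pos:
            out[p] = (x[p] - mean) / math.sqrt(var + EPS)
    return out


x, mask = ones(IN_SHAPE), ones(MASK_SHAPE)
buggy = masked_normalize(x, mask, fixed=False)
fixed = masked_normalize(x, mask, fixed=True)
print(round(buggy[(0, 0, 0, 0)], 4))  # -0.5773
print(fixed[(0, 0, 0, 0)])            # 0.0
```

With the count of 2, the mean comes out as 6 / 2 = 3 instead of 1, which is why every output equals (1 - 3) / sqrt(12 + 1e-3).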
See https://github.com/keras-team/keras/issues/19818 for more discussion/background.
I think a better workaround is to validate the shape of the mask in Keras.
The shape of the mask is correct in this example (according to https://github.com/keras-team/keras/issues/19818#issuecomment-2156142266), so validation wouldn't help in this case.
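A hypothetical shape check along the lines suggested above illustrates why validation would not catch this. The function name `mask_shape_ok` and its rule (the mask has rank one less than the input, and each mask axis is 1 or equals the matching input axis) are assumptions for illustration, not a Keras API.

```python
def mask_shape_ok(input_shape, mask_shape):
    """Hypothetical validator: the mask must align with input_shape[:-1],
    with every mask axis either 1 (broadcast) or equal to the input axis."""
    expected = tuple(input_shape[:-1])
    if len(mask_shape) != len(expected):
        return False
    return all(m in (1, n) for m, n in zip(mask_shape, expected))


# The mask from the report passes, so the wrong count has to be fixed in the
# moment computation itself rather than by rejecting the mask.
print(mask_shape_ok((1, 2, 3, 4), (1, 2, 1)))  # True
```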
Regarding "it doesn't account for the broadcasting": broadcasting from (2,) to (2, 3, 4) makes sense here, but elsewhere (e.g. in NumPy) broadcasting aligns the rightmost dimensions, i.e. (4,) broadcasts to (2, 3, 4).
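The two alignment conventions can be compared with a small sketch. `broadcast_shape` below follows the NumPy-style rule of padding the shorter shape on the left (aligning the rightmost dimensions); the hypothetical helper `pad_right` instead appends size-1 axes, which is how a (2,) or (1, 2, 1) mask has to be aligned to reach the input shape.

```python
def broadcast_shape(a, b):
    """Right-aligned (NumPy-style) broadcasting of two shapes.

    Raises ValueError when a pair of aligned dimensions is incompatible.
    """
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)  # pad on the LEFT
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for m, n in zip(a, b):
        if m != n and 1 not in (m, n):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(n if m == 1 else m)
    return tuple(out)


def broadcasts(a, b):
    """True if shapes a and b broadcast under the right-aligned rule."""
    try:
        broadcast_shape(a, b)
        return True
    except ValueError:
        return False


def pad_right(shape, ndim):
    """Left-aligned alignment: append size-1 axes (like expand_dims(-1))."""
    return tuple(shape) + (1,) * (ndim - len(shape))


print(broadcast_shape((4,), (2, 3, 4)))                # (2, 3, 4)
print(broadcasts((2,), (2, 3, 4)))                     # False
print(broadcast_shape(pad_right((2,), 3), (2, 3, 4)))  # (2, 3, 4)
print(broadcasts((1, 2, 1), (1, 2, 3, 4)))             # False
print(broadcast_shape(pad_right((1, 2, 1), 4), (1, 2, 3, 4)))  # (1, 2, 3, 4)
```

Under the right-aligned rule, the report's (1, 2, 1) mask does not broadcast to (1, 2, 3, 4) at all; it only does once trailing size-1 axes are appended.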