BatchNormalization gives incorrect output with masked inputs > 3 dimensions
The mean/variance calculations are incorrect, which means the inputs are not normalized correctly. For example:

```python
import keras

x = keras.ops.ones((1, 2, 3, 4))
x._keras_mask = keras.ops.ones((1, 2, 1))
y = keras.layers.BatchNormalization()(x, training=True)
print(keras.ops.mean(y, axis=-1))
```
gives the output

```
tf.Tensor([-0.57732624 -0.57732624 -0.57732624 -0.57732624], shape=(4,), dtype=float32)
```

instead of the correct normalized output `[0, 0, 0, 0]`.
The basic issue is that this calculation is incorrect: https://github.com/keras-team/keras/blob/efaaf85e19113400f23462cbafcef433cd95ad9c/keras/src/layers/normalization/batch_normalization.py#L310-L314. It doesn't account for the mask being broadcast against the input, so it computes a masked element count of 2 in the example above, when it should be 2 * 3 * 4 = 24.
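As a minimal NumPy sketch (not the Keras internals), the broadcast-aware count can be obtained by expanding the mask to the input's full shape before summing, so the count reflects every input element the mask covers rather than the mask's own element count:

```python
import numpy as np

x = np.ones((1, 2, 3, 4), dtype="float32")
mask = np.ones((1, 2, 1), dtype="float32")

# Add a trailing axis and broadcast the mask up to the input's shape,
# so each input element the mask covers is counted individually.
mask_b = np.broadcast_to(mask[..., None], x.shape)

# Reduce over all axes except the channel axis (-1), matching the
# layer's default normalization axis.
reduction_axes = (0, 1, 2)

# Per-channel count of masked *input* elements: 2 * 3 = 6 here
# (total 24 across channels), rather than the mask's own sum of 2.
count = mask_b.sum(axis=reduction_axes)

mean = (x * mask_b).sum(axis=reduction_axes) / count
var = ((x - mean) ** 2 * mask_b).sum(axis=reduction_axes) / count
```

With a correct count, an all-ones input yields a per-channel mean of 1 and variance of 0, so the normalized output has mean 0 as expected.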
See https://github.com/keras-team/keras/issues/19818 for more discussion/background.