Flax avg_pool also includes padding tokens when computing the mean
I am not sure if this is an issue, but it is definitely a possible cause of confusion: we currently implement `pooling.avg_pool` as `avg_pool(x) = lax.reduce_window(lax.add, x) / prod(window_size)`. If we use padding, we always divide by the full window size, even if the window contains padding tokens.
Example:
```python
import numpy as np
from flax.linen import avg_pool

xs = np.array([1.0, 2.0]).reshape((1, 2, 1))
avg_pool(xs, window_shape=(2,), strides=(1,), padding="SAME").reshape((2,))
# Result: [1.5, 1. ]
```
Is this what we want? The first result, (1+2)/2 = 1.5, makes sense, but the second result, 2/2 = 1, is a bit odd. Shouldn't we compute 2/1 = 2 instead?
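For reference, the current behavior can be reproduced directly with `lax.reduce_window`; this is only a sketch of the formula above, not Flax's actual implementation:

```python
import numpy as np
from jax import lax

# Illustrative sketch: sum each window (zero-padded under SAME) and divide by
# the full window size, regardless of how many padded positions the window has.
def naive_avg_pool(x, window_shape, strides, padding):
    dims = (1,) + window_shape + (1,)      # no pooling over batch/feature dims
    strides = (1,) + strides + (1,)
    summed = lax.reduce_window(x, 0.0, lax.add, dims, strides, padding)
    return summed / np.prod(window_shape)

xs = np.array([1.0, 2.0]).reshape((1, 2, 1))
print(naive_avg_pool(xs, (2,), (1,), "SAME").reshape((2,)))  # [1.5 1. ]
```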
Other frameworks do it as follows:
- PyTorch also uses zero padding, but it provides a manual override option: https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
- TF does not include padding tokens in the average and does not provide an override option.
Personally I feel that including padding tokens with value 0 is wrong (it seems like an arbitrary constant). At the very least we should be explicit about our choice and document it.
A possible solution for implementing average pooling that only counts non-padding tokens is to do an additional sum_pool2 over an input of the same shape filled with 1s, padded with 0s. Returning sum_pool / sum_pool2 then correctly ignores the padding tokens (see the sketch below).
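A minimal sketch of that idea, using `lax.reduce_window` directly (the helper name is illustrative, not part of the Flax API):

```python
import jax.numpy as jnp
from jax import lax

# Sketch of the proposal: divide the windowed sum of the inputs by the windowed
# sum of a ones-mask, so padded positions do not count towards the denominator.
def avg_pool_exclude_pad(x, window_shape, strides, padding):
    dims = (1,) + window_shape + (1,)
    strides = (1,) + strides + (1,)
    sum_pool = lax.reduce_window(x, 0.0, lax.add, dims, strides, padding)
    sum_pool2 = lax.reduce_window(jnp.ones_like(x), 0.0, lax.add, dims, strides, padding)
    return sum_pool / sum_pool2  # sum_pool2 < prod(window_shape) wherever padding was added

xs = jnp.array([1.0, 2.0]).reshape((1, 2, 1))
print(avg_pool_exclude_pad(xs, (2,), (1,), "SAME").reshape((2,)))  # [1.5 2. ]
```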
Discussed this offline with @jheek and @cgarciae. We agreed that the current behavior is not desirable: we assume that padding tokens for avg_pool are 0s and include them when computing the average, but we do not document this anywhere. TensorFlow has chosen to implement this differently, namely by excluding the padding tokens, and similarly, they do not document this in their APIs. PyTorch seems to have the best of both worlds: they allow the user to specify the behavior via a flag. This seems like something we could do as well (a sketch of such a flag follows below).
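For illustration, a sketch of what such a flag could look like; the parameter name `count_include_pad` is borrowed from `torch.nn.AvgPool2d` and is hypothetical here, not an existing Flax option:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Hypothetical flag-based variant: count_include_pad=True reproduces the current
# behavior (divide by the full window size); False excludes padded positions.
def avg_pool_with_flag(x, window_shape, strides, padding, count_include_pad=True):
    dims = (1,) + window_shape + (1,)
    strides = (1,) + strides + (1,)
    summed = lax.reduce_window(x, 0.0, lax.add, dims, strides, padding)
    if count_include_pad:
        return summed / np.prod(window_shape)
    counts = lax.reduce_window(jnp.ones_like(x), 0.0, lax.add, dims, strides, padding)
    return summed / counts

xs = jnp.array([1.0, 2.0]).reshape((1, 2, 1))
print(avg_pool_with_flag(xs, (2,), (1,), "SAME").reshape((2,)))                           # [1.5 1. ]
print(avg_pool_with_flag(xs, (2,), (1,), "SAME", count_include_pad=False).reshape((2,)))  # [1.5 2. ]
```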
I'm going to close this since the changes in #2448 seem to address the main point; feel free to reopen.